diff --git a/src/script/builder/builder.cc b/src/script/builder/builder.cc index 93bea0288434..c04a1b5f346a 100644 --- a/src/script/builder/builder.cc +++ b/src/script/builder/builder.cc @@ -45,6 +45,7 @@ Builder Builder::Current() { } TVM_REGISTER_NODE_TYPE(BuilderNode); +TVM_REGISTER_NODE_TYPE(FrameNode); } // namespace builder } // namespace script diff --git a/src/script/builder/builder.h b/src/script/builder/builder.h index 03c682ab6261..7357223dd64e 100644 --- a/src/script/builder/builder.h +++ b/src/script/builder/builder.h @@ -21,12 +21,39 @@ #include -#include "./frame.h" - namespace tvm { namespace script { namespace builder { +class FrameNode : public runtime::Object { + public: + std::vector> callbacks; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `callbacks` is not visited. + } + + void AddCallback(runtime::TypedPackedFunc callback) { callbacks.push_back(callback); } + + static constexpr const char* _type_key = "script.builder.Frame"; + TVM_DECLARE_BASE_OBJECT_INFO(FrameNode, runtime::Object); + + public: + virtual ~FrameNode() { + for (auto it = callbacks.rbegin(); it != callbacks.rend(); ++it) { + (*it)(); + } + } +}; + +class Frame : public runtime::ObjectRef { + public: + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Frame, ObjectRef, FrameNode); + + protected: + Frame() = default; +}; + class BuilderNode : public runtime::Object { public: runtime::Array frames; @@ -35,8 +62,21 @@ class BuilderNode : public runtime::Object { v->Visit("frames", &frames); // } - static constexpr const char* _type_key = "script.Builder"; - TVM_DECLARE_BASE_OBJECT_INFO(BuilderNode, runtime::Object); + static constexpr const char* _type_key = "script.builder.Builder"; + TVM_DECLARE_FINAL_OBJECT_INFO(BuilderNode, runtime::Object); + + public: + template + TFrame FindFrame() const { + using TFrameNode = typename TFrame::ContainerType; + for (auto it = frames.rbegin(); it != frames.rend(); ++it) { + if (const TFrameNode* p = (*it).template as()) { + return GetRef(p); + } + } + LOG(FATAL) << "IndexError: Cannot find frame: " << TFrameNode::_type_key; + throw; + } }; class Builder : public runtime::ObjectRef { diff --git a/src/script/builder/tir/for_frame.cc b/src/script/builder/tir/for_frame.cc new file mode 100644 index 000000000000..e17191036120 --- /dev/null +++ b/src/script/builder/tir/for_frame.cc @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./for_frame.h" + +namespace tvm { +namespace script { +namespace builder { +namespace tir { + +ForFrame::ForFrame(Array loop_vars, ForFrame::FMakeForLoop f_make_for_loop) { + ObjectPtr n = make_object(); + n->loop_vars = std::move(loop_vars); + n->f_make_for_loop = std::move(f_make_for_loop); + data_ = std::move(n); +} + +#define TVM_SCRIPT_BUILDER_TIR_FOR_CREATE(Method, Kind) \ + With Method(PrimExpr min, PrimExpr extent, Map attrs) { \ + ObjectPtr n = make_object(); \ + int bits = std::max(min.dtype().bits(), extent.dtype().bits()); \ + n->loop_vars = {tvm::tir::Var("v", DataType::Int(bits))}; \ + n->f_make_for_loop = [=](Array vars, tvm::tir::Stmt body) -> tvm::tir::For { \ + ICHECK_EQ(vars.size(), 1); \ + return tvm::tir::For(/*loop_var=*/vars[0], min, extent, Kind, body, \ + /*thread_binding=*/NullOpt, attrs); \ + }; \ + return With(n); \ + } + +TVM_SCRIPT_BUILDER_TIR_FOR_CREATE(Serial, tvm::tir::ForKind::kSerial); +TVM_SCRIPT_BUILDER_TIR_FOR_CREATE(Parallel, tvm::tir::ForKind::kParallel); +TVM_SCRIPT_BUILDER_TIR_FOR_CREATE(Vectorized, tvm::tir::ForKind::kVectorized); +TVM_SCRIPT_BUILDER_TIR_FOR_CREATE(Unroll, tvm::tir::ForKind::kUnrolled); + +#undef TVM_SCRIPT_BUILDER_TIR_FOR_CREATE + +With ThreadBinding(PrimExpr min, PrimExpr extent, String thread, + Map attrs) { + using namespace tvm::tir; + ObjectPtr n = make_object(); + int bits = std::max(min.dtype().bits(), extent.dtype().bits()); + n->loop_vars = {Var("v", DataType::Int(bits))}; + n->f_make_for_loop = [=](Array vars, Stmt body) -> For { + ICHECK_EQ(vars.size(), 1); + IterVar iter_var(Range(nullptr), Var(ObjectPtr(nullptr)), IterVarType::kThreadIndex, + thread); + return For(vars[0], min, extent, tvm::tir::ForKind::kThreadBinding, body, iter_var, attrs); + }; + return With(n); +} + +With Grid(Array extents) { + using namespace tvm::tir; + ObjectPtr n = make_object(); + n->loop_vars.reserve(extents.size()); + for (const auto& extent : extents) { + n->loop_vars.push_back(Var("v", extent.dtype())); + } + n->f_make_for_loop = [=](Array vars, Stmt body) -> Stmt { + ICHECK_EQ(extents.size(), vars.size()); + int n = extents.size(); + for (int i = n - 1; i >= 0; --i) { + Var var = vars[i]; + PrimExpr extent = extents[i]; + body = For(var, Integer(0), extent, ForKind::kSerial, body, /*thread_binding=*/NullOpt, {}); + } + return body; + }; + return With(n); +} + +TVM_REGISTER_NODE_TYPE(ForFrameNode); + +} // namespace tir +} // namespace builder +} // namespace script +} // namespace tvm diff --git a/src/script/builder/tir/for_frame.h b/src/script/builder/tir/for_frame.h new file mode 100644 index 000000000000..b4e634905eed --- /dev/null +++ b/src/script/builder/tir/for_frame.h @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_BUILDER_TIR_FOR_FRAME_H_ +#define TVM_SCRIPT_BUILDER_TIR_FOR_FRAME_H_ + +#include +#include +#include +#include + +#include "./tir.h" + +namespace tvm { +namespace script { +namespace builder { +namespace tir { + +class ForFrameNode : public TIRFrameNode { + public: + using FMakeForLoop = + runtime::TypedPackedFunc, tvm::tir::Stmt)>; + + Array loop_vars; + FMakeForLoop f_make_for_loop; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("loop_vars", &loop_vars); + // `f_make_for_loop` is not visited. + } + + static constexpr const char* _type_key = "script.builder.tir.ForFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(ForFrameNode, TIRFrameNode); +}; + +class ForFrame : public TIRFrame { + public: + using FMakeForLoop = ForFrameNode::FMakeForLoop; + + explicit ForFrame(Array loop_vars, FMakeForLoop f_make_for_loop); + + void EnterWithScope() { ICHECK(data_ != nullptr); } + + void ExitWithScope() { + ICHECK(data_ != nullptr); + data_.reset(); + } + + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ForFrame, TIRFrame, ForFrameNode); +}; + +With Serial(PrimExpr min, PrimExpr extent, Map annotations); +With Parallel(PrimExpr min, PrimExpr extent, Map annotations); +With Vectorized(PrimExpr min, PrimExpr extent, Map annotations); +With Unroll(PrimExpr min, PrimExpr extent, Map annotations); +With ThreadBinding(PrimExpr min, PrimExpr extent, String thread, + Map annotations); +With Grid(Array extents); + +} // namespace tir +} // namespace builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_BUILDER_TIR_FOR_FRAME_H_ diff --git a/src/script/builder/tir/prim_func_frame.cc b/src/script/builder/tir/prim_func_frame.cc new file mode 100644 index 000000000000..3736f692de53 --- /dev/null +++ b/src/script/builder/tir/prim_func_frame.cc @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "./prim_func_frame.h" + +namespace tvm { +namespace script { +namespace builder { +namespace tir { + +void Arg(tvm::tir::Var var) { + PrimFuncFrame frame = Builder::Current()->FindFrame(); + frame->args.push_back(var); +} + +void Arg(tvm::tir::Buffer buffer) { + using namespace tvm::tir; + PrimFuncFrame frame = Builder::Current()->FindFrame(); + Var handle(buffer->name + "_handle", DataType::Handle()); + frame->args.push_back(handle); + frame->buffer_map.Set(handle, buffer); +} + +TVM_REGISTER_NODE_TYPE(PrimFuncFrameNode); + +} // namespace tir +} // namespace builder +} // namespace script +} // namespace tvm diff --git a/src/script/builder/tir/prim_func_frame.h b/src/script/builder/tir/prim_func_frame.h new file mode 100644 index 000000000000..0be38a587347 --- /dev/null +++ b/src/script/builder/tir/prim_func_frame.h @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_BUILDER_TIR_PRIM_FUNC_FRAME_H_ +#define TVM_SCRIPT_BUILDER_TIR_PRIM_FUNC_FRAME_H_ + +#include "./tir.h" + +namespace tvm { +namespace script { +namespace builder { +namespace tir { + +class PrimFuncFrameNode : public TIRFrameNode { + public: + String name; + Array args; + Type ret_type; + Map buffer_map; + + void VisitAttrs(tvm::AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("name", &name); + v->Visit("args", &args); + v->Visit("ret_type", &ret_type); + v->Visit("buffer_map", &buffer_map); + } + + static constexpr const char* _type_key = "script.builder.tir.PrimFuncFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncFrameNode, TIRFrameNode); +}; + +class PrimFuncFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PrimFuncFrame, TIRFrame, PrimFuncFrameNode); +}; + +void Arg(tvm::tir::Var var); +void Arg(tvm::tir::Buffer buffer); + +} // namespace tir +} // namespace builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_BUILDER_TIR_PRIM_FUNC_FRAME_H_ diff --git a/src/script/builder/frame.cc b/src/script/builder/tir/tir.cc similarity index 90% rename from src/script/builder/frame.cc rename to src/script/builder/tir/tir.cc index b81ce415e229..faef2372909e 100644 --- a/src/script/builder/frame.cc +++ b/src/script/builder/tir/tir.cc @@ -16,14 +16,16 @@ * specific language governing permissions and limitations * under the License. */ -#include "./frame.h" +#include "./tir.h" namespace tvm { namespace script { namespace builder { +namespace tir { -TVM_REGISTER_NODE_TYPE(FrameNode); +TVM_REGISTER_NODE_TYPE(TIRFrameNode); +} // namespace tir } // namespace builder } // namespace script } // namespace tvm diff --git a/src/script/builder/frame.h b/src/script/builder/tir/tir.h similarity index 57% rename from src/script/builder/frame.h rename to src/script/builder/tir/tir.h index 5d10920f2744..d5638413b871 100644 --- a/src/script/builder/frame.h +++ b/src/script/builder/tir/tir.h @@ -16,43 +16,42 @@ * specific language governing permissions and limitations * under the License. */ -#ifndef TVM_SCRIPT_BUILDER_FRAME_H_ -#define TVM_SCRIPT_BUILDER_FRAME_H_ +#ifndef TVM_SCRIPT_BUILDER_TIR_TIR_H_ +#define TVM_SCRIPT_BUILDER_TIR_TIR_H_ -#include +#include + +#include "../builder.h" namespace tvm { namespace script { namespace builder { +namespace tir { -class FrameNode : public runtime::Object { +class TIRFrameNode : public FrameNode { public: - std::vector> callbacks; + Array stmts; void VisitAttrs(tvm::AttrVisitor* v) { - // `callbacks` is not visited. + FrameNode::VisitAttrs(v); + v->Visit("stmts", &stmts); } - void AddCallback(runtime::TypedPackedFunc callback) { callbacks.push_back(callback); } - - static constexpr const char* _type_key = "script.Frame"; - TVM_DECLARE_BASE_OBJECT_INFO(FrameNode, runtime::Object); - - public: - virtual ~FrameNode() { - for (auto it = callbacks.rbegin(); it != callbacks.rend(); ++it) { - (*it)(); - } - } + static constexpr const char* _type_key = "script.builder.tir.TIRFrame"; + TVM_DECLARE_BASE_OBJECT_INFO(TIRFrameNode, FrameNode); }; -class Frame : public runtime::ObjectRef { +class TIRFrame : public Frame { public: - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Frame, ObjectRef, FrameNode); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TIRFrame, Frame, TIRFrameNode); + + protected: + TIRFrame() = default; }; +} // namespace tir } // namespace builder } // namespace script } // namespace tvm -#endif // TVM_SCRIPT_BUILDER_FRAME_H_ +#endif // TVM_SCRIPT_BUILDER_TIR_TIR_H_