Skip to content

Commit

Permalink
fix naming convention
Browse files Browse the repository at this point in the history
  • Loading branch information
icemelon committed Jan 16, 2017
1 parent 1c1ba2a commit 7988c28
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 10 deletions.
17 changes: 15 additions & 2 deletions include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,21 @@
namespace tvm {
namespace ir {

using Halide::Internal::equal;
using Halide::Internal::simplify;
inline bool Equal(Expr a, Expr b) {
return Halide::Internal::equal(a, b);
}

inline bool Equal(Stmt a, Stmt b) {
return Halide::Internal::equal(a, b);
}

inline Expr Simplify(Expr a) {
return Halide::Internal::simplify(a);
}

inline Stmt Simplify(Stmt a) {
return Halide::Internal::simplify(a);
}

/*!
* \brief Schedule s' dependent operations.
Expand Down
10 changes: 5 additions & 5 deletions src/c_api/c_api_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,20 @@ TVM_REGISTER_API(_pass_Simplify)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
if (dynamic_cast<Expr::ContainerType*>(args.at(0).sptr.get())) {
*ret = simplify(args.at(0).operator Expr());
*ret = Simplify(args.at(0).operator Expr());
} else {
*ret = simplify(args.at(0).operator Stmt());
*ret = Simplify(args.at(0).operator Stmt());
}
});

TVM_REGISTER_API(_pass_equal)
TVM_REGISTER_API(_pass_Equal)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
CHECK(args.at(1).type_id == kNodeHandle);
if (dynamic_cast<Expr::ContainerType*>(args.at(0).sptr.get())) {
*ret = equal(args.at(0).operator Expr(), args.at(1).operator Expr());
*ret = Equal(args.at(0).operator Expr(), args.at(1).operator Expr());
} else {
*ret = equal(args.at(0).operator Stmt(), args.at(1).operator Stmt());
*ret = Equal(args.at(0).operator Stmt(), args.at(1).operator Stmt());
}
});

Expand Down
6 changes: 3 additions & 3 deletions tests/python/test_pass_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
def test_simplify():
x = tvm.Var('x')
e1 = tvm.ir_pass.Simplify(x + 2 + 1)
assert(tvm.ir_pass.equal(e1, x + 3))
assert(tvm.ir_pass.Equal(e1, x + 3))
e2 = tvm.ir_pass.Simplify(x * 3 + 5 * x)
assert(tvm.ir_pass.equal(e2, x * 8))
assert(tvm.ir_pass.Equal(e2, x * 8))
e3 = tvm.ir_pass.Simplify(x - x / 3 * 3)
assert(tvm.ir_pass.equal(e3, tvm.make.Mod(x, 3)))
assert(tvm.ir_pass.Equal(e3, tvm.make.Mod(x, 3)))


def test_verify_ssa():
Expand Down

0 comments on commit 7988c28

Please sign in to comment.