[TIR][TRANSFORM] Return value support in tir.tvm_call_packed #7932

Merged 1 commit on Apr 30, 2021
38 changes: 23 additions & 15 deletions include/tvm/tir/builtin.h
@@ -334,23 +334,27 @@ TVM_DLL const Op& tvm_stack_make_array();
/*!
* \brief See pseudo code
*
* int tvm_call_packed(name, TVMValue* args) {
* return_type tvm_call_packed(name, TVMValue* args) {
* TVMValue ret_value;
* int ret_code;
* ModuleNode* env = GetCurrentEnv();
* const PackedFunc* f = env->GetFuncFromEnv(name);
* (*f)(args, type_code_of(args), len(args));
* return 0;
* (*f)(args, type_code_of(args), len(args), &ret_value, &ret_code);
* // return type can be int, float, handle.
* return cast(return_type, ret_value.v_return_type);
* }
*/
TVM_DLL const Op& tvm_call_packed();

/*!
* \brief See pseudo code
*
* int tvm_call_trace_packed(name, TVMValue* args) {
* return_type tvm_call_trace_packed(name, TVMValue* args) {
* ModuleNode* env = GetCurrentEnv();
* const PackedFunc* f = env->GetFuncFromEnv(name);
* (*f)(args, type_code_of(args), len(args));
* return 0;
* // return type can be int, float, handle.
* return cast(return_type, ret_value.v_return_type);
* }
*/
TVM_DLL const Op& tvm_call_trace_packed();
@@ -372,16 +376,18 @@ TVM_DLL const Op& tvm_thread_context();
* \brief Lowered version of call packed, the space of value and
* type codes are explicitly allocated.
*
* int tvm_call_packed_lowered(name,
* TVMValue* value_stack,
* int* tcode_stack,
* int begin,
* int end) {
* return_type tvm_call_packed_lowered(name,
* TVMValue* value_stack,
* int* tcode_stack,
* int begin,
* int end) {
* ModuleNode* env = GetCurrentEnv();
* const PackedFunc* f = env->GetFuncFromEnv(name);
* f->CallPacked(TVMArgs(value_stack[begin:end],
* tcode_stack[begin:end]),
* TVMRetValue(value_stack + end, tcode_stack + end));
* // return type can be int, float, handle.
* return cast(return_type, load_return_from(tcode_stack + end))
* }
*/
TVM_DLL const Op& tvm_call_packed_lowered();
@@ -391,16 +397,18 @@ TVM_DLL const Op& tvm_call_packed_lowered();
* type codes are explicitly allocated. The return value is the
* (end - 1) value on the stack.
*
* int tvm_call_trace_packed_lowered(name,
* TVMValue* value_stack,
* int* tcode_stack,
* int begin,
* int end) {
* return_type tvm_call_trace_packed_lowered(name,
* TVMValue* value_stack,
* int* tcode_stack,
* int begin,
* int end) {
* ModuleNode* env = GetCurrentEnv();
* const PackedFunc* f = env->GetFuncFromEnv(name);
* f->CallPacked(TVMArgs(value_stack[begin:end],
* tcode_stack[begin:end]),
* TVMRetValue(value_stack + end, tcode_stack + end));
* // return type can be int, float, handle.
* return cast(return_type, load_return_from(tcode_stack + end))
* }
*/
TVM_DLL const Op& tvm_call_trace_packed_lowered();
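For context (not part of this diff): with the updated pseudo code above, a tir.tvm_call_packed call now carries the packed function's return type in its dtype instead of a hard-coded int32. A minimal sketch, mirroring the testing.echo packed function used by the new test in this PR:

import tvm
from tvm import tir

def packed_echo(value):
    # The dtype of the call is the packed function's return type
    # (previously this was always int32).
    return tir.call_intrin(
        value.dtype,
        tvm.ir.Op.get("tir.tvm_call_packed"),
        "testing.echo",  # registered packed function that returns its argument
        value,
    )

x = packed_echo(tir.const(1.5, "float32"))
print(x.dtype)  # float32

The lowering in lower_tvm_builtin.cc (below) forwards this dtype to tvm_call_packed_lowered, which casts the returned TVMValue payload accordingly.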
20 changes: 20 additions & 0 deletions python/tvm/tir/ir_builder.py
@@ -374,6 +374,26 @@ def _exit_cb():

return WithScope(None, _exit_cb)

def let(self, var_name, value):
"""Create a new let stmt binding.

Parameters
----------
var_name : str
The name of the variable

value : PrimExpr
The value to be bound

Returns
-------
var : tvm.tir.Var
The variable that can be used in subsequent emits.
"""
var = _expr.Var(var_name, dtype=value.dtype)
self.emit(lambda x: _stmt.LetStmt(var, value, x))
return var

def allocate(self, dtype, shape, name="buf", scope=None):
"""Create a allocate statement.

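For context (not part of this diff), a short usage sketch of the new let helper, assuming only standard ir_builder APIs; the returned Var can be used by statements emitted afterwards, which the LetStmt then wraps:

import tvm
from tvm import tir

ib = tir.ir_builder.create()
n = tir.Var("n", "int32")

m = ib.let("m", n * 2)          # bind n * 2 to a fresh Var named "m"
A = ib.allocate("int32", (1,), name="A", scope="global")
A[0] = m + 1                    # the let-bound var is used by a later emit

print(ib.get())                 # LetStmt(m, n * 2, <allocate/store body>)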
77 changes: 46 additions & 31 deletions src/tir/transforms/lower_tvm_builtin.cc
@@ -89,14 +89,19 @@ class BuiltinLower : public StmtExprMutator {
}

Stmt VisitStmt(const Stmt& s) final {
// allocate space to hold prepare stmts before s
prep_seq_stack_.emplace_back(std::vector<Stmt>());

auto stmt = StmtExprMutator::VisitStmt(s);
auto& scope = alloca_scope_.back();
ICHECK_EQ(scope.run_shape_stack, -1);
ICHECK_EQ(scope.run_array_stack, 0);

if (prep_seq_.size() != 0) {
Stmt ret = SeqStmt::Flatten(prep_seq_, stmt);
prep_seq_.clear();
auto prep_seq = std::move(prep_seq_stack_.back());
prep_seq_stack_.pop_back();

if (prep_seq.size() != 0) {
Stmt ret = SeqStmt::Flatten(prep_seq, stmt);
return ret;
} else {
return stmt;
@@ -192,6 +197,7 @@ class BuiltinLower : public StmtExprMutator {
// if args.size() == 0, it represents a scalar shape ()
ICHECK(!alloca_scope_.empty());
auto& scope = alloca_scope_.back();
auto& prep_seq = prep_seq_stack_.back();
if (scope.run_shape_stack == -1) {
scope.run_shape_stack = 0;
}
@@ -201,57 +207,63 @@ class BuiltinLower : public StmtExprMutator {
op = expr.as<CallNode>();
// no need to perform any store for a scalar shape
for (size_t i = 0; i < op->args.size(); ++i) {
prep_seq_.emplace_back(Store(scope.stack_shape, cast(DataType::Int(64), op->args[i]),
ConstInt32(stack_begin + i), const_true(1)));
prep_seq.emplace_back(Store(scope.stack_shape, cast(DataType::Int(64), op->args[i]),
ConstInt32(stack_begin + i), const_true(1)));
}
return AddressOffset(scope.stack_shape, DataType::Int(64), stack_begin);
}
// make array
PrimExpr MakeArray(const CallNode* op) {
ICHECK(!alloca_scope_.empty());
auto& scope = alloca_scope_.back();
auto& prep_seq = prep_seq_stack_.back();

size_t idx = scope.run_array_stack;
scope.run_array_stack += 1;
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CallNode>();
prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrData, op->args[0]));
prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrShape, op->args[1]));

prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrData, op->args[0]));
prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrShape, op->args[1]));
PrimExpr strides = op->args[2];
if (!strides.defined() || is_zero(strides)) {
strides = make_zero(DataType::Handle());
}
prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrStrides, strides));
prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrNDim, op->args[3]));
prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrStrides, strides));
prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrNDim, op->args[3]));
DataType dtype = op->args[4].dtype();
prep_seq_.emplace_back(
prep_seq.emplace_back(
TVMStructSet(scope.stack_array, idx, builtin::kArrTypeCode,
make_const(DataType::UInt(8), static_cast<int>(dtype.code()))));
prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrTypeBits,
make_const(DataType::UInt(8), dtype.bits())));
prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrTypeLanes,
make_const(DataType::UInt(16), dtype.lanes())));
prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrTypeBits,
make_const(DataType::UInt(8), dtype.bits())));
prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrTypeLanes,
make_const(DataType::UInt(16), dtype.lanes())));
// set byte offset
int data_bytes = GetVectorBytes(dtype);
PrimExpr byte_offset = op->args[5];
if (!is_zero(byte_offset)) {
byte_offset = byte_offset * make_const(byte_offset.dtype(), data_bytes);
}
prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrByteOffset,
cast(DataType::UInt(64), byte_offset)));
prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrByteOffset,
cast(DataType::UInt(64), byte_offset)));
ICHECK(device_type_.defined()) << "Unknown device type in current IR";
ICHECK(device_id_.defined()) << "Unknown device id in current IR";
prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrDeviceId,
cast(DataType::Int(32), device_id_)));
prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrDeviceType,
cast(DataType::Int(32), device_type_)));
prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrDeviceId,
cast(DataType::Int(32), device_id_)));
prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrDeviceType,
cast(DataType::Int(32), device_type_)));
return TVMStructGet(DataType::Handle(), scope.stack_array, idx, builtin::kArrAddr);
}
// call packed.
PrimExpr MakeCallPacked(const CallNode* op) {
auto& scope = alloca_scope_.back();
auto& prep_seq = prep_seq_stack_.back();

int64_t restore_shape_stack = scope.run_shape_stack;
size_t restore_array_stack = scope.run_array_stack;
size_t arg_stack_begin = scope.run_arg_stack;

scope.run_arg_stack += op->args.size();
// Specially handle the buffer packed intrinsic
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
@@ -264,15 +276,15 @@
if (t != api_type) {
arg = Cast(api_type, arg);
}
prep_seq_.emplace_back(TVMStructSet(scope.stack_value,
static_cast<int>(arg_stack_begin + i - 1),
builtin::kTVMValueContent, arg));
prep_seq.emplace_back(TVMStructSet(scope.stack_value,
static_cast<int>(arg_stack_begin + i - 1),
builtin::kTVMValueContent, arg));
int arg_tcode = api_type.code();
if (api_type.is_handle() && arg.as<StringImmNode>()) {
arg_tcode = kTVMStr;
}
if (IsArrayHandle(arg)) arg_tcode = kTVMDLTensorHandle;
prep_seq_.emplace_back(
prep_seq.emplace_back(
Store(scope.stack_tcode, ConstInt32(arg_tcode), stack_index, const_true(1)));
}
// UPDATE stack value
@@ -285,12 +297,15 @@
Array<PrimExpr> packed_args = {op->args[0], scope.stack_value, scope.stack_tcode,
ConstInt32(arg_stack_begin),
ConstInt32(arg_stack_begin + op->args.size() - 1)};
return Call(DataType::Int(32), builtin::tvm_call_packed_lowered(), packed_args);
// call_packed_lowered needs to do the type casting properly
return Call(op->dtype, builtin::tvm_call_packed_lowered(), packed_args);
}

PrimExpr MakeCallTracePacked(const CallNode* op) {
ICHECK(!alloca_scope_.empty());
auto& scope = alloca_scope_.back();
auto& prep_seq = prep_seq_stack_.back();

int64_t restore_shape_stack = scope.run_shape_stack;
size_t restore_array_stack = scope.run_array_stack;
size_t arg_stack_begin = scope.run_arg_stack;
@@ -307,12 +322,12 @@
if (t != api_type) {
arg = Cast(api_type, arg);
}
prep_seq_.emplace_back(TVMStructSet(scope.stack_value,
static_cast<int>(arg_stack_begin + i - 1),
builtin::kTVMValueContent, arg));
prep_seq.emplace_back(TVMStructSet(scope.stack_value,
static_cast<int>(arg_stack_begin + i - 1),
builtin::kTVMValueContent, arg));
int arg_tcode = api_type.code();
ICHECK(!IsArrayHandle(arg)) << "Trace does not support Buffers";
prep_seq_.emplace_back(
prep_seq.emplace_back(
Store(scope.stack_tcode, ConstInt32(arg_tcode), stack_index, const_true(1)));
}
// UPDATE stack value
@@ -344,8 +359,8 @@ class BuiltinLower : public StmtExprMutator {
return false;
}

// The prepration sequence to be emitted.
std::vector<Stmt> prep_seq_;
// The preparation sequence to be emitted before the current statement.
std::vector<std::vector<Stmt>> prep_seq_stack_;
PrimExpr device_type_;
PrimExpr device_id_;

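A note on the prep_seq_ to prep_seq_stack_ change (a simplified model in plain Python, not TVM code): each statement visit now gets its own list of preparation statements, so setup emitted while rewriting a nested statement is flattened in front of that statement rather than leaking into an outer one:

# Simplified model of the prep_seq_stack_ pattern (not TVM code).
prep_seq_stack = []

def visit_stmt(stmt, rewrite):
    prep_seq_stack.append([])           # fresh prep list for this statement
    new_stmt = rewrite(stmt)            # rewriting may emit preparation steps
    prep_seq = prep_seq_stack.pop()
    return prep_seq + [new_stmt] if prep_seq else [new_stmt]

def emit_prep(setup):
    prep_seq_stack[-1].append(setup)    # attach to the innermost statement only

def rewrite_outer(stmt):
    # Visiting a nested statement pushes its own prep list, so the packed-call
    # setup stays attached to the inner statement.
    nested = visit_stmt("call_packed", lambda s: (emit_prep("store stack_value"), s)[-1])
    return [stmt, nested]

outer = visit_stmt("outer_stmt", rewrite_outer)
print(outer)  # [['outer_stmt', ['store stack_value', 'call_packed']]]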
40 changes: 37 additions & 3 deletions tests/python/unittest/test_tir_transform_lower_tvm_builtin.py
@@ -133,11 +133,45 @@ def check_packed_func(target="llvm"):
tvm.ir.assert_structural_equal(alloca_shape, expected_stmt, map_free_vars=True)


def test_packed_func():
def test_lower_packed_func():
check_packed_func("llvm")
giuseros (Contributor) commented on Apr 27, 2021:

Can you also add a check for the C backend? I tried this in AOT (to call _linked_params_lookup), but the C backend does not appear to handle return values correctly.

Author (Member) replied:

I believe it previously only worked for the non-C backend, but a recent PR might have added support.

giuseros (Contributor) commented on Apr 27, 2021:

Is the PR that adds non-LLVM support already merged? If so, I think it should be OK to add a test. If not, we might add a TODO and add the test once that PR is merged.

Author (Member) replied:

I agree a TODO is a good idea.

check_packed_func("stackvm")


@tvm.testing.requires_llvm
def test_call_packed_return_non_i32():
# Call a packed function that returns non-i32 types
expected_value = np.array([1.2, 1.4], dtype="float32")

def packed_echo(value):
return tvm.tir.call_intrin(
value.dtype, tvm.ir.Op.get("tir.tvm_call_packed"), "testing.echo", value
)

def build_tir():
Ab = tvm.tir.decl_buffer((2,), "float32")
ib = tvm.tir.ir_builder.create()
Aptr = ib.buffer_ptr(Ab)
# return f32
# Aptr[0] = testing.echo(expected_value[0])
Aptr[0] = packed_echo(tvm.tir.const(expected_value[0], "float32"))
# return handle
# let Aptr_var = testing.echo(Aptr) in Aptr_var[1] = expected_value[1]
Aptr_var = ib.let("Aptr_dup", packed_echo(Aptr.asobject()))
ib.emit(tvm.tir.Store(Aptr_var, tvm.tir.const(expected_value[1], "float32"), 1))

stmt = ib.get()
return tvm.IRModule.from_expr(
tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "packed_test")
)

mod = build_tir()
f = tvm.build(mod, None, "llvm")
a = tvm.nd.array(np.zeros(2, dtype="float32"))
f(a)
tvm.testing.assert_allclose(a.asnumpy(), expected_value)


if __name__ == "__main__":
# Test cases for issue: https://github.com/apache/tvm/issues/7246
test_packed_func()
test_call_packed_return_non_i32()
test_lower_packed_func()