Skip to content

Commit

Permalink
[VM] Minor refactor for C++ memory alloc (#7413)
Browse files Browse the repository at this point in the history
* started moving things to header

* directly call InvokeTVMOp

* done all memory op

* also refactor AllocTensor

* declare Prod

* remove cached func for Add, Multiply, Divide

* lint fix

* revert test change

* remove tensor.h and declare Prod in pattern_utils.h
  • Loading branch information
masahi authored Feb 6, 2021
1 parent fc08430 commit 1f846f0
Show file tree
Hide file tree
Showing 9 changed files with 200 additions and 166 deletions.
73 changes: 0 additions & 73 deletions src/relay/op/device_copy.cc

This file was deleted.

83 changes: 61 additions & 22 deletions src/relay/op/memory/memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,22 @@
* \brief Operators for manifest shape-aware memory allocation in Relay.
*/

#include "memory.h"

#include <tvm/node/node.h>
#include <tvm/relay/attrs/memory.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/runtime/data_type.h>
#include <tvm/topi/elemwise.h>

#include <vector>

#include "../../transforms/infer_layout_utils.h"
#include "../op_common.h"
#include "../type_relations.h"
#include "tvm/relay/attrs/device_copy.h"

namespace tvm {
namespace relay {
Expand All @@ -42,15 +48,16 @@ TVM_REGISTER_NODE_TYPE(AllocTensorAttrs);
// The passing value in attrs and args doesn't seem super great.
// We should consider a better solution, i.e the type relation
// being able to see the arguments as well?
TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_storage")
.set_body_typed([](Expr size, Expr alignment, TVMContext ctx, DataType dtype_hint) {
auto attrs = make_object<AllocStorageAttrs>();
attrs->dtype = dtype_hint;
attrs->device_id = ctx.device_id;
attrs->device_type = ctx.device_type;
static const Op& op = Op::Get("memory.alloc_storage");
return Call(op, {size, alignment}, Attrs(attrs), {});
});
Expr AllocStorage(Expr size, Expr alignment, TVMContext ctx, DataType dtype_hint) {
auto attrs = make_object<AllocStorageAttrs>();
attrs->dtype = dtype_hint;
attrs->device_id = ctx.device_id;
attrs->device_type = ctx.device_type;
static const Op& op = Op::Get("memory.alloc_storage");
return Call(op, {size, alignment}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_storage").set_body_typed(AllocStorage);

bool AllocStorageRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
Expand Down Expand Up @@ -90,19 +97,20 @@ RELAY_REGISTER_OP("memory.alloc_storage")
return {topi::identity(inputs[0])};
});

TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_tensor")
.set_body_typed([](Expr storage, Expr offset, tvm::relay::Expr shape, DataType dtype,
Array<IndexExpr> assert_shape) {
auto attrs = make_object<AllocTensorAttrs>();
attrs->dtype = dtype;
if (assert_shape.defined()) {
attrs->assert_shape = assert_shape;
} else {
attrs->const_shape = Downcast<Constant>(shape);
}
static const Op& op = Op::Get("memory.alloc_tensor");
return Call(op, {storage, offset, shape}, Attrs(attrs), {});
});
Expr AllocTensor(Expr storage, Expr offset, tvm::relay::Expr shape, DataType dtype,
Array<IndexExpr> assert_shape) {
auto attrs = make_object<AllocTensorAttrs>();
attrs->dtype = dtype;
if (assert_shape.defined()) {
attrs->assert_shape = assert_shape;
} else {
attrs->const_shape = Downcast<Constant>(shape);
}
static const Op& op = Op::Get("memory.alloc_tensor");
return Call(op, {storage, offset, shape}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_tensor").set_body_typed(AllocTensor);

std::vector<int64_t> FromConstShape(Constant konst) {
runtime::NDArray shape = konst->data;
Expand Down Expand Up @@ -299,5 +307,36 @@ TVM_REGISTER_GLOBAL("relay.op.memory._make.ToTupleType")
return ToTupleType(t, std::vector<Expr>(array.begin(), array.end()));
});

// relay.device_copy
TVM_REGISTER_NODE_TYPE(DeviceCopyAttrs);

Expr DeviceCopy(Expr data, int src_dev_type, int dst_dev_type) {
auto attrs = make_object<DeviceCopyAttrs>();
attrs->src_dev_type = src_dev_type;
attrs->dst_dev_type = dst_dev_type;
static const Op& op = Op::Get("device_copy");
return Call(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.device_copy").set_body_typed(DeviceCopy);

RELAY_REGISTER_OP("device_copy")
.describe(R"code(
Copy data from one tensor to another. The source and destination might be
on different devices.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input data.")
.set_support_level(10)
.add_type_rel("Identity", IdentityRel)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
[](const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_dtype) -> Array<te::Tensor> {
return {topi::identity(inputs[0])};
});

} // namespace relay
} // namespace tvm
46 changes: 46 additions & 0 deletions src/relay/op/memory/memory.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/relay/op/memory/memory.h
* \brief Operators for memory related operations in Relay.
*/

#ifndef TVM_RELAY_OP_MEMORY_MEMORY_H_
#define TVM_RELAY_OP_MEMORY_MEMORY_H_

#include <vector>

#include "tvm/relay/expr.h"

namespace tvm {
namespace relay {

Expr AllocStorage(Expr size, Expr alignment, TVMContext ctx, DataType dtype_hint);
Expr DeviceCopy(Expr data, int src_dev_type, int dst_dev_type);
Expr AllocTensor(Expr storage, Expr offset, tvm::relay::Expr shape, DataType dtype,
Array<IndexExpr> assert_shape);
Expr ToTupleType(const Type& ty, const std::vector<Expr>& exprs);
std::vector<Expr> FromTupleType(const Type& type, const Expr& expr);
std::vector<TensorType> FlattenTupleType(const Type& type);

} // namespace relay
} // namespace tvm

#endif // TVM_RELAY_OP_MEMORY_MEMORY_H_
6 changes: 5 additions & 1 deletion src/relay/op/tensor/reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,11 @@ Array<te::Tensor> ProdCompute(const Attrs& attrs, const Array<te::Tensor>& input
return ReduceCompute(attrs, inputs, out_type, topi::prod);
}

RELAY_REGISTER_REDUCE_OP("prod")
TVM_REGISTER_GLOBAL("relay.op._make.prod").set_body_typed(Prod);

RELAY_REGISTER_OP("prod")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.describe(R"code(Computes the products of array elements over given axes.
Example::
Expand Down
49 changes: 29 additions & 20 deletions src/relay/op/vm/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
* \brief Dialect operators for Relay VM.
*/

#include "vm.h"

#include <tvm/relay/attrs/memory.h>
#include <tvm/relay/attrs/vm.h>
#include <tvm/relay/expr.h>
Expand All @@ -30,6 +32,8 @@
#include <tvm/runtime/data_type.h>
#include <tvm/topi/elemwise.h>

#include <utility>

#include "../../transforms/infer_layout_utils.h"
#include "../op_common.h"
#include "../type_relations.h"
Expand All @@ -52,20 +56,23 @@ RELAY_REGISTER_OP("vm.shape_of")
.set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);

TVM_REGISTER_GLOBAL("relay.op.vm.shape_of").set_body_typed([](Expr expr) {
Expr ShapeOf(Expr expr) {
auto attrs = make_object<ShapeOfAttrs>();
attrs->dtype = DataType::Int(64);
static const Op& op = Op::Get("vm.shape_of");
return Call(op, {expr}, Attrs(attrs), {});
});
}

TVM_REGISTER_GLOBAL("relay.op.vm.shape_of").set_body_typed(ShapeOf);

Expr ShapeFunc(Expr func, Expr inputs, Expr outputs, Array<tvm::Integer> is_input) {
static const Op& op = Op::Get("vm.shape_func");
auto attrs = make_object<ShapeFuncAttrs>();
attrs->is_input = is_input;
return Call(op, {func, inputs, outputs}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.vm.shape_func")
.set_body_typed([](Expr func, Expr inputs, Expr outputs, Array<tvm::Integer> is_input) {
static const Op& op = Op::Get("vm.shape_func");
auto attrs = make_object<ShapeFuncAttrs>();
attrs->is_input = is_input;
return Call(op, {func, inputs, outputs}, Attrs(attrs), {});
});
TVM_REGISTER_GLOBAL("relay.op.vm.shape_func").set_body_typed(ShapeFunc);

bool ShapeFuncRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
Expand Down Expand Up @@ -162,10 +169,11 @@ bool InvokeTVMOpRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
return true;
}

TVM_REGISTER_GLOBAL("relay.op.vm.invoke_tvm_op")
.set_body_typed([](Expr func, Expr inputs, Expr outputs) {
return Call(Op::Get("vm.invoke_tvm_op"), {func, inputs, outputs}, Attrs());
});
Expr InvokeTVMOp(Expr func, Expr inputs, Expr outputs) {
return Call(Op::Get("vm.invoke_tvm_op"), {func, inputs, outputs}, Attrs());
}

TVM_REGISTER_GLOBAL("relay.op.vm.invoke_tvm_op").set_body_typed(InvokeTVMOp);

RELAY_REGISTER_OP("vm.invoke_tvm_op")
.describe(R"code(Invoke an operation compiled by TVM.)code" TVM_ADD_FILELINE)
Expand Down Expand Up @@ -212,13 +220,14 @@ RELAY_REGISTER_OP("vm.reshape_tensor")
.set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);

TVM_REGISTER_GLOBAL("relay.op.vm.reshape_tensor")
.set_body_typed([](Expr data, Expr shape, Array<PrimExpr> newshape) {
static const Op& op = Op::Get("vm.reshape_tensor");
auto attrs = make_object<ReshapeTensorAttrs>();
attrs->newshape = std::move(newshape);
return Call(op, {data, shape}, Attrs(attrs), {});
});
Expr ReshapeTensor(Expr data, Expr shape, Array<PrimExpr> newshape) {
static const Op& op = Op::Get("vm.reshape_tensor");
auto attrs = make_object<ReshapeTensorAttrs>();
attrs->newshape = std::move(newshape);
return Call(op, {data, shape}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.vm.reshape_tensor").set_body_typed(ReshapeTensor);

} // namespace relay
} // namespace tvm
40 changes: 40 additions & 0 deletions src/relay/op/vm/vm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/relay/op/vm/vm.h
* \brief Dialect operators for Relay VM.
*/
#ifndef TVM_RELAY_OP_VM_VM_H_
#define TVM_RELAY_OP_VM_VM_H_

#include "tvm/relay/expr.h"

namespace tvm {
namespace relay {

Expr InvokeTVMOp(Expr func, Expr inputs, Expr outputs);
Expr ShapeFunc(Expr func, Expr inputs, Expr outputs, Array<tvm::Integer> is_input);
Expr ShapeOf(Expr expr);
Expr ReshapeTensor(Expr data, Expr shape, Array<PrimExpr> newshape);

} // namespace relay
} // namespace tvm

#endif // TVM_RELAY_OP_VM_VM_H_
Loading

0 comments on commit 1f846f0

Please sign in to comment.