Skip to content

Commit

Permalink
Rename FnValue into Closure (apache#33)
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch committed Aug 16, 2018
1 parent 1440d3c commit 5675bf2
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 25 deletions.
14 changes: 7 additions & 7 deletions relay/include/tvm/relay/ir/value.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,28 +90,28 @@ class FloatValueNode : public ValueNode {

RELAY_DEFINE_VALUE(FloatValue, FloatValueNode);

class FnValue;
class Closure;

/*! \brief A floating point value. */
class FnValueNode : public ValueNode {
class ClosureNode : public ValueNode {
public:
tvm::Map<LocalId, Value> env;
Function func;

FnValueNode() {}
ClosureNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("env", &env);
v->Visit("func", &func);
}

TVM_DLL static FnValue make(tvm::Map<LocalId, Value> env, Function func);
TVM_DLL static Closure make(tvm::Map<LocalId, Value> env, Function func);

static constexpr const char* _type_key = "relay.FnValue";
TVM_DECLARE_NODE_TYPE_INFO(FnValueNode, ValueNode);
static constexpr const char* _type_key = "relay.Closure";
TVM_DECLARE_NODE_TYPE_INFO(ClosureNode, ValueNode);
};

RELAY_DEFINE_VALUE(FnValue, FnValueNode);
RELAY_DEFINE_VALUE(Closure, ClosureNode);

class ProductValue;

Expand Down
2 changes: 1 addition & 1 deletion relay/python/relay/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
IntValue = value.IntValue
FloatValue = value.FloatValue
BoolValue = value.BoolValue
FnValue = value.FnValue
Closure = value.Closure
TensorValue = value.TensorValue

# Type
Expand Down
2 changes: 1 addition & 1 deletion relay/python/relay/ir/value.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class BoolValue(Value):


@register_relay_node
class FnValue(Value):
class Closure(Value):
env: Dict[LocalId, Value]
func: Function

Expand Down
14 changes: 7 additions & 7 deletions relay/src/tvm/relay/evaluator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ struct Evaluator : ExprFunctor<Value(const Expr& n)> {
return this->stack.lookup(local);
}

Value invoke(const FnValue& func, tvm::Array<Value>& args);
Value invoke(const Closure& func, tvm::Array<Value>& args);
Value invoke_intrinsic(const IntrinsicId& op, tvm::Array<Value>& args,
const tvm::Array<Type>& arg_types);
Value Eval(const Expr& expr);
Expand Down Expand Up @@ -264,7 +264,7 @@ Value Evaluator::VisitExpr_(const FunctionNode* op) {
for (const auto& id : free_id) {
free_var_values.Set(id, Eval(id));
}
return FnValueNode::make(free_var_values, GetRef<Function>(op));
return ClosureNode::make(free_var_values, GetRef<Function>(op));
}

Value Evaluator::invoke_intrinsic(const IntrinsicId& op,
Expand Down Expand Up @@ -335,7 +335,7 @@ Value Evaluator::invoke_intrinsic(const IntrinsicId& op,

// An efficient interpreter needs a faster way to access args, relative to stack
// pointer?
Value Evaluator::invoke(const FnValue& closure, tvm::Array<Value>& args) {
Value Evaluator::invoke(const Closure& closure, tvm::Array<Value>& args) {
// In the VM we should support building a frame from free vars and parameters
// we should compute the frame layout statically.

Expand Down Expand Up @@ -401,8 +401,8 @@ Value Evaluator::VisitExpr_(const CallNode* op) {
return this->invoke_intrinsic(GetRef<IntrinsicId>(intr), args, op->ty_args);
} else {
auto fn_val = this->VisitExpr(op->fn);
if (const FnValueNode* closure = fn_val.as<FnValueNode>()) {
return this->invoke(GetRef<FnValue>(closure), args);
if (const ClosureNode* closure = fn_val.as<ClosureNode>()) {
return this->invoke(GetRef<Closure>(closure), args);
} else {
throw EvalError(
"Type error, expected function value in the call position");
Expand Down Expand Up @@ -567,8 +567,8 @@ TVM_REGISTER_API("relay.eval.invoke")
// type check the call before execution.
Evaluator eval(env);
auto fn_val = eval.VisitExpr(id);
if (const FnValueNode* closure = fn_val.as<FnValueNode>()) {
*ret = eval.invoke(GetRef<FnValue>(closure), relay_args);
if (const ClosureNode* closure = fn_val.as<ClosureNode>()) {
*ret = eval.invoke(GetRef<Closure>(closure), relay_args);
} else {
throw EvalError(
"Type error, expected function value in the call position");
Expand Down
14 changes: 7 additions & 7 deletions relay/src/tvm/relay/ir/value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,21 +63,21 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p->stream << "FloatValueNode(" << node->value << ")";
});

FnValue FnValueNode::make(tvm::Map<LocalId, Value> env, Function func) {
std::shared_ptr<FnValueNode> n = std::make_shared<FnValueNode>();
Closure ClosureNode::make(tvm::Map<LocalId, Value> env, Function func) {
std::shared_ptr<ClosureNode> n = std::make_shared<ClosureNode>();
n->env = std::move(env);
n->func = std::move(func);
return FnValue(n);
return Closure(n);
}

TVM_REGISTER_API("relay.make.FnValue")
TVM_REGISTER_API("relay.make.Closure")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = FnValueNode::make(args[0], args[1]);
*ret = ClosureNode::make(args[0], args[1]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<FnValueNode>([](const FnValueNode *node, tvm::IRPrinter *p) {
p->stream << "FnValueNode(todo)";
.set_dispatch<ClosureNode>([](const ClosureNode *node, tvm::IRPrinter *p) {
p->stream << "ClosureNode(todo)";
});

ProductValue ProductValueNode::make(tvm::Array<Value> value) {
Expand Down
4 changes: 2 additions & 2 deletions relay/tests/python/test_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,10 @@ def test_float_value():
assert fv.value == 1.135


def test_fn_value():
def test_closure():
params = [Param(LocalId("x"), IntType(2))]
fn = Function(params, IntType(2), LocalId("x"))
fn_val = FnValue({}, fn)
fn_val = Closure({}, fn)
assert fn_val.func == fn


Expand Down

0 comments on commit 5675bf2

Please sign in to comment.