diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h
index 9868fa699e76..2ad364b96c3f 100644
--- a/include/tvm/relax/expr.h
+++ b/include/tvm/relax/expr.h
@@ -196,22 +196,26 @@ class Binding : public ObjectRef {
 class MatchShape;
 class MatchShapeNode : public BindingNode {
  public:
-  Array<PrimExpr> pattern;
   Expr value;
+  Array<PrimExpr> pattern;
+  Var var;
 
   void VisitAttrs(AttrVisitor* v) {
-    v->Visit("pattern", &pattern);
     v->Visit("value", &value);
+    v->Visit("pattern", &pattern);
+    v->Visit("var", &var);
     v->Visit("span", &span);
   }
 
   bool SEqualReduce(const MatchShapeNode* other, SEqualReducer equal) const {
-    return equal(pattern, other->pattern) && equal(value, other->value);
+    return equal(value, other->value) && equal(pattern, other->pattern)
+        && equal(var, other->var);
   }
 
   void SHashReduce(SHashReducer hash_reduce) const {
-    hash_reduce(pattern);
     hash_reduce(value);
+    hash_reduce(pattern);
+    hash_reduce(var);
   }
 
   static constexpr const char* _type_key = "relax.expr.MatchShape";
@@ -222,7 +226,8 @@ class MatchShapeNode : public BindingNode {
 
 class MatchShape : public Binding {
  public:
-  TVM_DLL explicit MatchShape(Array<PrimExpr> pattern, Expr value, Span span = Span());
+  TVM_DLL explicit MatchShape(Expr value, Array<PrimExpr> pattern,
+                              Var var, Span span = Span());
   TVM_DEFINE_OBJECT_REF_METHODS(MatchShape, Binding, MatchShapeNode);
 };
 
diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py
index d4f78feef070..1cc067d73458 100644
--- a/python/tvm/relax/expr.py
+++ b/python/tvm/relax/expr.py
@@ -103,11 +103,12 @@ def __init__(self, span: Span = None) -> None:
 
 @tvm._ffi.register_object("relax.expr.MatchShape")
 class MatchShape(Binding):
-    pattern: List[PrimExpr]
     value: Expr
+    pattern: List[PrimExpr]
+    var: Var
 
-    def __init__(self, pattern: List[PrimExpr], value: Expr, span: Span = None) -> None:
-        self.__init_handle_by_constructor__(_ffi_api.MatchShape, pattern, value, span)
+    def __init__(self, value: Expr, pattern: List[PrimExpr], var: Var, span: Span = None) -> None:
+        self.__init_handle_by_constructor__(_ffi_api.MatchShape, value, pattern, var, span)
 
 
 @tvm._ffi.register_object("relax.expr.VarBinding")
diff --git a/src/relax/expr.cc b/src/relax/expr.cc
index 08e1a01f1d94..13445f1788b1 100644
--- a/src/relax/expr.cc
+++ b/src/relax/expr.cc
@@ -65,10 +65,10 @@ Var::Var(Id vid, Optional<Expr> shape_annotation, Optional<Type> type_annotation
 }
 
 TVM_REGISTER_GLOBAL("relax.Var")
-    .set_body_typed([](String name_hint, Optional<Expr> shape_annotation,
-                       Optional<Type> type_annotation, Span span) {
-      return Var(name_hint, shape_annotation, type_annotation, span);
-    });
+.set_body_typed([](String name_hint, Optional<Expr> shape_annotation,
+                   Optional<Type> type_annotation, Span span) {
+  return Var(name_hint, shape_annotation, type_annotation, span);
+});
 
 TVM_REGISTER_NODE_TYPE(DataflowVarNode);
 
@@ -83,10 +83,10 @@ DataflowVar::DataflowVar(Id vid, Optional<Expr> shape_annotation, Optional<Type>
 }
 
 TVM_REGISTER_GLOBAL("relax.DataflowVar")
-    .set_body_typed([](String name_hint, Optional<Expr> shape_annotation,
-                       Optional<Type> type_annotation, Span span) {
-      return DataflowVar(name_hint, shape_annotation, type_annotation, span);
-    });
+.set_body_typed([](String name_hint, Optional<Expr> shape_annotation,
+                   Optional<Type> type_annotation, Span span) {
+  return DataflowVar(name_hint, shape_annotation, type_annotation, span);
+});
 
 Binding::Binding(Span span) {
   ObjectPtr<BindingNode> n = make_object<BindingNode>();
@@ -96,22 +96,25 @@ Binding::Binding(Span span) {
 
 TVM_REGISTER_NODE_TYPE(BindingNode);
 
-TVM_REGISTER_GLOBAL("relax.Binding").set_body_typed([](Span span) { return Binding(span); });
+TVM_REGISTER_GLOBAL("relax.Binding").set_body_typed([](Span span) {
+  return Binding(span);
+});
 
 TVM_REGISTER_NODE_TYPE(MatchShapeNode);
 
-MatchShape::MatchShape(Array<PrimExpr> pattern, Expr value, Span span) {
+MatchShape::MatchShape(Expr value, Array<PrimExpr> pattern, Var var, Span span) {
   ObjectPtr<MatchShapeNode> n = make_object<MatchShapeNode>();
-  n->pattern = std::move(pattern);
   n->value = std::move(value);
+  n->pattern = std::move(pattern);
+  n->var = std::move(var);
   n->span = span;
   data_ = std::move(n);
 }
 
 TVM_REGISTER_GLOBAL("relax.MatchShape")
-    .set_body_typed([](Array<PrimExpr> pattern, Expr value, Span span) {
-      return MatchShape(pattern, value, span);
-    });
+.set_body_typed([](Expr value, Array<PrimExpr> pattern, Var var, Span span) {
+  return MatchShape(value, pattern, var, span);
+});
 
 TVM_REGISTER_NODE_TYPE(VarBindingNode);
 
@@ -182,9 +185,10 @@ Function::Function(runtime::Optional<GlobalVar> name, Array<Var> params, Expr bo
 }
 
 TVM_REGISTER_GLOBAL("relax.Function")
-    .set_body_typed([](runtime::Optional<GlobalVar> name, Array<Var> params, Expr body,
-                       Type ret_type,
-                       Span span) { return Function(name, params, body, ret_type, span); });
+.set_body_typed([](runtime::Optional<GlobalVar> name, Array<Var> params,
+                   Expr body, Type ret_type, Span span) {
+  return Function(name, params, body, ret_type, span);
+});
 
 TVM_REGISTER_NODE_TYPE(ExternFuncNode);
 
diff --git a/tests/python/relax/test_expr.py b/tests/python/relax/test_expr.py
index 60e8bd340f5b..f515d331f9ea 100644
--- a/tests/python/relax/test_expr.py
+++ b/tests/python/relax/test_expr.py
@@ -1,7 +1,6 @@
 import tvm
 from tvm import tir
 from tvm import relax as rx
-from tvm.ir import TensorType
 import numpy as np
 
 
@@ -11,7 +10,7 @@ def test_var() -> None:
     assert v0.shape_ is None
     assert v0.type_annotation is None
     shape_anno = [54, 96]
-    type_anno = TensorType(shape_anno, "float32")
+    type_anno = rx.DynTensorType(2, "float32")
     v1 = rx.Var("v1", shape_anno, type_anno)
     assert v1.name_hint == "v1"
     for s0, s1 in zip(v1.shape_, shape_anno):
@@ -25,7 +24,7 @@ def test_dataflow_var() -> None:
     assert v0.shape_ is None
     assert v0.type_annotation is None
     shape_anno = [54, 96]
-    type_anno = TensorType(shape_anno, "float16")
+    type_anno = rx.DynTensorType(2, "float16")
    v1 = rx.DataflowVar("v1", shape_anno, type_anno)
     assert v1.name_hint == "v1"
     for s0, s1 in zip(v1.shape_, shape_anno):
@@ -35,13 +34,34 @@
 
 
 def test_match_shape() -> None:
+    # match_shape([16, 8], [m, n])
     m = tir.Var("m", dtype="int32")
     n = tir.Var("n", dtype="int32")
     shape = rx.const([16, 8], "int32")
-    b0 = rx.MatchShape([m, n], shape)
+    var = rx.Var("v0", type_annotation=rx.ShapeType())
+    b0 = rx.MatchShape(shape, [m, n], var)
+    assert b0.value == shape
     assert b0.pattern[0] == m
     assert b0.pattern[1] == n
-    assert b0.value == shape
+    assert b0.var is not None
+    assert b0.var.checked_type_ == rx.ShapeType()
+
+    # var1: Tensor[(m, n), "float32"] =
+    #     match_shape(var0: Tensor[_, "float32"], [m, n])
+    type_anno0 = rx.DynTensorType(-1, "float32")
+    value = rx.Var("value", type_annotation=type_anno0)
+
+    shape_anno = [m, n]
+    type_anno = rx.DynTensorType(2, "float32")
+    var = rx.Var("v1", shape_anno, type_anno)
+    b1 = rx.MatchShape(value, [m, n], var)
+    assert b1.value == value
+    assert b1.pattern[0] == m
+    assert b1.pattern[1] == n
+    assert b1.var is not None
+    for s0, s1 in zip(b1.var.shape, [m, n]):
+        assert s0 == s1
+    assert b1.var.checked_type_ == rx.DynTensorType(2, "float32")
 
 
 def test_var_binding() -> None:
@@ -56,7 +76,7 @@ def test_binding_block() -> None:
     m = tir.Var("m", dtype="int32")
     n = tir.Var("n", dtype="int32")
     shape = rx.const([16, 8], "int32")
-    b0 = rx.MatchShape([m, n], shape)
+    b0 = rx.MatchShape(shape, [m, n], rx.Var("v0"))
 
     v0 = rx.Var("v0")
     val = rx.const(np.random.rand(24, 56))
@@ -71,7 +91,7 @@ def test_dataflow_block() -> None:
     m = tir.Var("m", dtype="int32")
     n = tir.Var("n", dtype="int32")
     shape = rx.const([16, 8], "int32")
-    b0 = rx.MatchShape([m, n], shape)
+    b0 = rx.MatchShape(shape, [m, n], rx.Var("v0"))
 
     v0 = rx.Var("v0")
     val = rx.const(np.random.rand(24, 56))
@@ -105,7 +125,7 @@ def test_func():
     bindings = [rx.VarBinding(x, rx.const(1))]
     blocks = [rx.BindingBlock(bindings)]
     seqe = rx.SeqExpr(blocks, x)
-    ret_type = TensorType(None, "float32")
+    ret_type = rx.DynTensorType(-1, "float32")
     func = rx.Function([x], seqe, ret_type, rx.GlobalVar("func"))
     assert func.params[0] == x
     assert func.body == seqe