Skip to content

Commit

Permalink
[SparseTIR] Parser, Printer, Roundtrip (apache#14)
Browse files Browse the repository at this point in the history
* SparseBlock scope handler (part 1)

* SparseBlock scope handler (part 2)

* SparseBlock scope handler (part 3)

* SparseBlock scope handler (fix 1)

* Add SparseBufferLoad/Store on Python side

* Parser for SparseBufferLoad/Store

* Add SparseBlock to Python __init__

* StmtFunctor for SparseBlock

* Ensure at least one dimension for SparseBuffer

* Make `axis` field of SpIterVar mandatory

* SparseBlock scope handler (fix 2)

* Update Axis syntax by removing `name` parameter

* Move to intrin.py

* Add filed `from_sparse` to DenseFixedAxis

* SparseTIR script printer

* Roundtrip test

* `update_symbol` bug fix

* Fix attr visit in SparseBuffer

* Define then compare in SparseBlock

* Fix printer bug for SparseBuffer

* Enable graph match for Axis and SparseBuffer

* Complete HashReduce and EqualReduce for AxisTree and SparseBuffer

* Fix typo

* Rename test

* Bug fix 1

* Bug fix 2

* Add more tests
  • Loading branch information
MasterJH5574 committed Nov 10, 2021
1 parent d19f493 commit 64c1103
Show file tree
Hide file tree
Showing 20 changed files with 783 additions and 234 deletions.
2 changes: 1 addition & 1 deletion include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -656,7 +656,7 @@ class BufferLoad : public PrimExpr {
*/
class SparseBufferLoadNode : public PrimExprNode {
public:
/*! \brief The buffer variable. */
/*! \brief The buffer to be loaded. */
SparseBuffer buffer;
/*! \brief The indices location to be loaded. */
Array<PrimExpr> indices;
Expand Down
88 changes: 60 additions & 28 deletions include/tvm/tir/sparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,23 +99,48 @@ class DenseAxis : public Axis {
TVM_DEFINE_OBJECT_REF_METHODS(DenseAxis, Axis, DenseAxisNode);
};

/*!
* \brief Sparse axis whose column indices is not consecutive.
*/
class SparseAxisNode : public AxisNode {
public:
static constexpr const char* _type_key = "tir.sparse.SparseAxis";
TVM_DECLARE_BASE_OBJECT_INFO(SparseAxisNode, AxisNode);
};

/*!
* \brief Managed reference to SparseAxisNode.
* \sa SparseAxisNode
*/
class SparseAxis : public Axis {
public:
TVM_DEFINE_OBJECT_REF_METHODS(SparseAxis, Axis, SparseAxisNode);
};

/*!
* \brief Dense axis with fixed length per row.
*/
class DenseFixedAxisNode : public DenseAxisNode {
public:
Optional<SparseAxis> from_sparse;

void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("length", &length);
v->Visit("from_sparse", &from_sparse);
}

bool SEqualReduce(const DenseAxisNode* other, SEqualReducer equal) const {
return equal(name, other->name) && equal(length, other->length);
bool SEqualReduce(const DenseFixedAxisNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return equal(name, other->name) && equal(length, other->length) &&
equal(from_sparse, other->from_sparse);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce->MarkGraphNode();
hash_reduce(name);
hash_reduce(length);
hash_reduce(from_sparse);
}

static constexpr const char* _type_key = "tir.sparse.DenseFixedAxis";
Expand All @@ -128,7 +153,8 @@ class DenseFixedAxisNode : public DenseAxisNode {
*/
class DenseFixedAxis : public DenseAxis {
public:
TVM_DLL explicit DenseFixedAxis(String name, PrimExpr length);
TVM_DLL explicit DenseFixedAxis(String name, PrimExpr length,
Optional<SparseAxis> from_sparse = NullOpt);

TVM_DEFINE_OBJECT_REF_METHODS(DenseFixedAxis, DenseAxis, DenseFixedAxisNode);
};
Expand All @@ -144,10 +170,12 @@ class DenseVariableAxisNode : public DenseAxisNode {
}

bool SEqualReduce(const DenseVariableAxisNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return equal(name, other->name) && equal(length, other->length) && equal(indptr, other->indptr);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce->MarkGraphNode();
hash_reduce(name);
hash_reduce(length);
hash_reduce(indptr);
Expand All @@ -168,24 +196,6 @@ class DenseVariableAxis : public DenseAxis {
TVM_DEFINE_OBJECT_REF_METHODS(DenseVariableAxis, DenseAxis, DenseVariableAxisNode);
};

/*!
* \brief Sparse axis whose column indices is not consecutive.
*/
class SparseAxisNode : public AxisNode {
public:
static constexpr const char* _type_key = "tir.sparse.SparseAxis";
TVM_DECLARE_BASE_OBJECT_INFO(SparseAxisNode, AxisNode);
};

/*!
* \brief Managed reference to SparseAxisNode.
* \sa SparseAxisNode
*/
class SparseAxis : public Axis {
public:
TVM_DEFINE_OBJECT_REF_METHODS(SparseAxis, Axis, SparseAxisNode);
};

/*!
* \brief Sparse axis with fixed number of non-zero columns per row.
*/
Expand All @@ -203,11 +213,13 @@ class SparseFixedAxisNode : public SparseAxisNode {
}

bool SEqualReduce(const SparseFixedAxisNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return equal(name, other->name) && equal(length, other->length) &&
equal(indices, other->indices) && equal(num_cols, other->num_cols);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce->MarkGraphNode();
hash_reduce(name);
hash_reduce(length);
hash_reduce(indices);
Expand Down Expand Up @@ -245,11 +257,13 @@ class SparseVariableAxisNode : public SparseAxisNode {
}

bool SEqualReduce(const SparseVariableAxisNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return equal(name, other->name) && equal(length, other->length) &&
equal(indptr, other->indptr) && equal(indices, other->indices);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce->MarkGraphNode();
hash_reduce(name);
hash_reduce(length);
hash_reduce(indptr);
Expand Down Expand Up @@ -277,13 +291,27 @@ class SparseVariableAxis : public SparseAxis {
class AxisTreeNode : public Object {
public:
// unordered map that stores the parent relationship between axes.
std::unordered_map<String, Optional<String>, ObjectPtrHash, ObjectPtrEqual> parent;
Map<String, Optional<String>> parent;
// unordered map that stores the children relationship between axes.
std::unordered_map<Optional<String>, Array<String>, ObjectPtrHash, ObjectPtrEqual> children;
Map<Optional<String>, Array<String>> children;

void VisitAttrs(AttrVisitor* v) {
v->Visit("parent", &parent);
v->Visit("children", &children);
}

bool SEqualReduce(const AxisTreeNode* other, SEqualReducer equal) const {
return equal(parent, other->parent) && equal(children, other->children);
}

void VisitAttrs(AttrVisitor* v) {}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(parent);
hash_reduce(children);
}

static constexpr const char* _type_key = "tir.sparse.AxisTree";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(AxisTreeNode, Object);
};

Expand Down Expand Up @@ -313,22 +341,26 @@ class SparseBufferNode : public Object {
inline int ndim() const { return static_cast<int>(axes.size()); }

void VisitAttrs(AttrVisitor* v) {
v->Visit("length", &axes);
v->Visit("num_cols", &data);
v->Visit("axes", &axes);
v->Visit("data", &data);
v->Visit("name", &name);
}

bool SEqualReduce(const SparseBufferNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return equal(axes, other->axes) && equal(data, other->data) && equal(name, other->name);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce->MarkGraphNode();
hash_reduce(axes);
hash_reduce(data);
hash_reduce(name);
}

static constexpr const char* _type_key = "tir.sparse.SparseBuffer";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(SparseBufferNode, Object);
};

Expand Down Expand Up @@ -359,7 +391,7 @@ class SpIterVarNode : public Object {
PrimExpr max_extent;
SpIterKind kind;
bool is_reduction;
Optional<Axis> axis;
Axis axis;

void VisitAttrs(AttrVisitor* v) {
v->Visit("var", &var);
Expand Down Expand Up @@ -392,7 +424,7 @@ class SpIterVarNode : public Object {
class SpIterVar : public ObjectRef {
public:
TVM_DLL explicit SpIterVar(Var var, PrimExpr max_extent, SpIterKind kind, bool is_reduction,
Optional<Axis> axis = NullOpt);
Axis axis);

/*!
* \return the corresponding var in the IterVar.
Expand Down
12 changes: 6 additions & 6 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -335,11 +335,11 @@ class BufferStore : public Stmt {
* buffer[i, j] = value;
*
* \endcode
* \sa SparseBufferLoad
* \sa SparseBufferStore
*/
class SparseBufferStoreNode : public StmtNode {
public:
/*! \brief The buffer variable. */
/*! \brief The sparse buffer to be accessed. */
SparseBuffer buffer;
/*! \brief The value to be stored. */
PrimExpr value;
Expand Down Expand Up @@ -1303,17 +1303,17 @@ class SparseBlockNode : public StmtNode {
}

bool SEqualReduce(const SparseBlockNode* other, SEqualReducer equal) const {
return equal(sp_iter_vars, other->sp_iter_vars) &&
equal(sp_struct2param_map, other->sp_struct2param_map) && equal(name, other->name) &&
equal(body, other->body) && equal(init, other->init);
return equal(sp_iter_vars, other->sp_iter_vars) && equal(name, other->name) &&
equal(body, other->body) && equal(init, other->init) &&
equal(sp_struct2param_map, other->sp_struct2param_map);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(sp_iter_vars);
hash_reduce(sp_struct2param_map);
hash_reduce(name);
hash_reduce(body);
hash_reduce(init);
hash_reduce(sp_struct2param_map);
}

static constexpr const char* _type_key = "tir.SparseBlock";
Expand Down
4 changes: 4 additions & 0 deletions include/tvm/tir/stmt_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
virtual R VisitStmt_(const EvaluateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const BlockNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const BlockRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const SparseBlockNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmtDefault_(const Object* op, Args...) {
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
return R();
Expand Down Expand Up @@ -126,6 +127,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
IR_STMT_FUNCTOR_DISPATCH(BufferRealizeNode);
IR_STMT_FUNCTOR_DISPATCH(BlockNode);
IR_STMT_FUNCTOR_DISPATCH(BlockRealizeNode);
IR_STMT_FUNCTOR_DISPATCH(SparseBlockNode);
return vtable;
}
};
Expand Down Expand Up @@ -169,6 +171,7 @@ class TVM_DLL StmtVisitor : protected StmtFunctor<void(const Stmt&)> {
void VisitStmt_(const EvaluateNode* op) override;
void VisitStmt_(const BlockNode* op) override;
void VisitStmt_(const BlockRealizeNode* op) override;
void VisitStmt_(const SparseBlockNode* op) override;
};

/*!
Expand Down Expand Up @@ -270,6 +273,7 @@ class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> {
Stmt VisitStmt_(const EvaluateNode* op) override;
Stmt VisitStmt_(const BlockNode* op) override;
Stmt VisitStmt_(const BlockRealizeNode* op) override;
Stmt VisitStmt_(const SparseBlockNode* op) override;
/*!
* \brief Alternative advance method for SeqStmtNode.
*
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/script/context_maintainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def update_symbol(
self, name: str, symbol: Union[Buffer, Var, SparseBuffer, Axis], node: synr.ast.Node
):
"""Append a symbol into current scope"""
if isinstance(symbol, (Buffer, Var, SparseBuffer, Axis)):
if isinstance(symbol, (Buffer, SparseBuffer, Axis)):
if name in self.symbols[0]:
self.report_error("Duplicate Buffer name: " + symbol.name, node.span)
self.symbols[0][name] = symbol
Expand Down
10 changes: 10 additions & 0 deletions python/tvm/script/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,14 @@ def transform_SubscriptAssign(self, node):
indexes,
span=tvm_span_from_synr(node.span),
)
elif isinstance(symbol, tvm.tir.sparse.SparseBuffer):
# SparseBufferStore
return tvm.tir.SparseBufferStore(
symbol,
tvm.runtime.convert(rhs, span=rhs_span),
indexes,
span=tvm_span_from_synr(node.span),
)
else:
if len(indexes) != 1:
self.report_error(
Expand Down Expand Up @@ -881,6 +889,8 @@ def transform_Subscript(self, node):
return BufferSlice(
symbol, indexes, self.report_error, span=tvm_span_from_synr(node.span)
)
elif isinstance(symbol, tvm.tir.sparse.SparseBuffer):
return tvm.tir.SparseBufferLoad(symbol, indexes, span=tvm_span_from_synr(node.span))
elif isinstance(symbol, tvm.container.Array):
if len(indexes) > 1:
self.report_error(
Expand Down
43 changes: 42 additions & 1 deletion python/tvm/script/tir/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,18 @@
"""TVM Script Parser Intrinsic Classes"""
# pylint: disable=redefined-builtin, relative-beyond-top-level
import builtins
from typing import List, Any
from typing import List, Optional, Any

import tvm.tir
from tvm.ir import Span
from tvm.tir.sparse import (
Axis,
DenseFixedAxis,
DenseVariableAxis,
SpIterVar,
SparseFixedAxis,
SparseVariableAxis,
)
from ..registry import register
from ..utils import get_param_list, tvm_span_from_synr

Expand Down Expand Up @@ -244,3 +253,35 @@ def comm_reducer(lambda_io, identities, span):
lambda_output = (lambda_output,)

return tvm.tir.CommReducer(x, y, lambda_output, identities, span)


@register
def to_dense(axis: Axis, span: Optional[Span] = None):
if isinstance(axis, (SparseFixedAxis, SparseVariableAxis)):
return DenseFixedAxis(axis.name + "_dense", axis.length, axis)
else:
return axis


@register
def cord(axis: Axis, span: Optional[Span] = None):
# The field `var` and `is_reduction` will be updated in SparseBlock scope handler
var_temp = tvm.te.var()
if isinstance(axis, DenseVariableAxis):
return SpIterVar(var_temp, axis.length, SpIterVar.DenseVariable, False, axis)
else:
return SpIterVar(var_temp, axis.length, SpIterVar.DenseFixed, False, axis)


@register
def pos(axis: Axis, span: Optional[Span] = None):
# The field `var` and `is_reduction` will be updated in SparseBlock scope handler
var_temp = tvm.te.var()
if isinstance(axis, DenseFixedAxis):
return SpIterVar(var_temp, axis.length, SpIterVar.DenseFixed, False, axis)
elif isinstance(axis, DenseVariableAxis):
return SpIterVar(var_temp, axis.length, SpIterVar.DenseVariable, False, axis)
elif isinstance(axis, SparseFixedAxis):
return SpIterVar(var_temp, axis.length, SpIterVar.SparseFixed, False, axis)
else:
return SpIterVar(var_temp, axis.length, SpIterVar.SparseVariable, False, axis)
Loading

0 comments on commit 64c1103

Please sign in to comment.