[RELAY][PYTORCH] isNan, isinf, isfinite, ceil, clamp, round ops
siju-samuel committed Apr 13, 2020
1 parent 6805d54 commit fa1b859
Showing 7 changed files with 201 additions and 2 deletions.
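
With this change, PyTorch graphs using these ops can be imported directly. A minimal end-to-end sketch; the module, input name, and shape list are illustrative, and it assumes the relay.frontend.from_pytorch(script_module, shape_list) entry point of this TVM version:

import torch
from tvm import relay

class CheckOps(torch.nn.Module):
    def forward(self, x):
        # Exercises two of the newly supported ops.
        return torch.isnan(torch.clamp(x, min=-0.5, max=0.5))

inp = torch.rand(1, 3, 10, 10)
scripted = torch.jit.trace(CheckOps().eval(), inp)
mod, params = relay.frontend.from_pytorch(scripted, [("input0", inp.shape)])
print(mod)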
1 change: 1 addition & 0 deletions docs/frontend/tensorflow.rst
@@ -162,6 +162,7 @@ Supported Ops
- Identity
- IsFinite
- IsInf
- IsNan
- LeakyRelu
- LeftShift
- Less
54 changes: 53 additions & 1 deletion python/tvm/relay/frontend/pytorch.py
@@ -1118,12 +1118,45 @@ def _impl(inputs, input_types):
        return _op.tensor.sqrt(data)
    return _impl


def _rsqrt():
    def _impl(inputs, input_types):
        data = inputs[0]
        return _op.tensor.rsqrt(data)
    return _impl


def _ceil():
    def _impl(inputs, input_types):
        data = inputs[0]
        return _op.ceil(data)
    return _impl


def _clamp():
    def _impl(inputs, input_types):
        data = inputs[0]
        # torch.clamp may omit either bound; substitute the float32
        # extremes so _op.clip always receives both limits. Compare
        # against None so an explicit bound of 0 is not dropped.
        amin = inputs[1] if inputs[1] is not None else np.finfo(np.float32).min
        amax = inputs[2] if inputs[2] is not None else np.finfo(np.float32).max
        return _op.clip(data, amin, amax)
    return _impl


def _floor():
    def _impl(inputs, input_types):
        data = inputs[0]
        return _op.floor(data)
    return _impl


def _round():
    def _impl(inputs, input_types):
        data = inputs[0]
        return _op.round(data)
    return _impl


def _to():
    def _impl(inputs, input_types):
        data = inputs[0]
@@ -1232,6 +1265,18 @@ def _impl(inputs, input_types):
    return _impl


def _isfinite():
    def _impl(inputs, input_types):
        return _op.isfinite(inputs[0])
    return _impl


def _isnan():
    def _impl(inputs, input_types):
        return _op.isnan(inputs[0])
    return _impl


def _list_getitem(prelude):
    def _impl(inputs, input_types):
        return prelude.nth(inputs[0], _wrap_const(inputs[1]))
@@ -1429,7 +1474,11 @@ def _get_convert_map(prelude):
"aten::std" : _std(),
"aten::var" : _variance(),
"aten::sqrt" : _sqrt(),
'aten::floor' : _floor(),
"aten::rsqrt" : _rsqrt(),
"aten::ceil" : _ceil(),
"aten::clamp" : _clamp(),
"aten::floor" : _floor(),
"aten::round" : _round(),
"aten::detach" : _identity(),
"aten::upsample_bilinear2d" : _upsample("bilinear"),
"aten::upsample_nearest2d" : _upsample("nearest_neighbor"),
@@ -1439,6 +1488,9 @@ def _get_convert_map(prelude):
"aten::le" : _elemwise("less_equal"),
"aten::ge" : _elemwise("greater_equal"),
"aten::ne" : _elemwise("not_equal"),
"aten::eq" : _elemwise("equal"),
"aten::isfinite" : _isfinite(),
"aten::isnan" : _isnan(),
"aten::Bool" : _Bool(),
"aten::Float" : _Float(),
"aten::neg" : _neg(),
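Each key in the convert map above is a TorchScript operator name and each value is a converter closure with the (inputs, input_types) signature. An op not covered here can be wired in the same way through the frontend's custom_convert_map argument; a minimal sketch, with aten::trunc used purely as an illustration:

from tvm import relay
from tvm.relay import op as _op

def _trunc():
    def _impl(inputs, input_types):
        # Same closure convention as the converters added in this commit.
        return _op.trunc(inputs[0])
    return _impl

# mod, params = relay.frontend.from_pytorch(
#     scripted_model, shape_list, custom_convert_map={"aten::trunc": _trunc()})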
1 change: 1 addition & 0 deletions python/tvm/relay/op/_tensor.py
@@ -66,6 +66,7 @@
register_broadcast_schedule("less_equal")
register_broadcast_schedule("greater")
register_broadcast_schedule("greater_equal")
register_broadcast_schedule("isnan")
register_broadcast_schedule("isfinite")
register_broadcast_schedule("isinf")
register_injective_schedule("maximum")
16 changes: 16 additions & 0 deletions python/tvm/relay/op/tensor.py
@@ -1010,6 +1010,22 @@ def ndarray_size(data, dtype="int32"):
    return _make.ndarray_size(data, dtype)


def isnan(data):
    """Check nan in input data element-wise.

    Parameters
    ----------
    data : relay.Expr
        The input data

    Returns
    -------
    result : relay.Expr
        The computed result.
    """
    return _make.isnan(data)


def isfinite(data):
    """Compute element-wise finiteness of data.
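The new relay.isnan can also be exercised directly from Python; a quick sketch, assuming the interpreter path via relay.create_executor:

import numpy as np
from tvm import relay

x = relay.var("x", shape=(5,), dtype="float32")
func = relay.Function([x], relay.isnan(x))
data = np.array([1.0, np.inf, 2.0, -np.inf, np.nan], dtype="float32")
result = relay.create_executor().evaluate(func)(data)
print(result)  # expected: [False False False False True]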
11 changes: 10 additions & 1 deletion src/relay/op/tensor/unary.cc
@@ -426,6 +426,15 @@ ElemwiseArbitraryLayout)
.set_support_level(10)
.set_attr<FTVMCompute>("FTVMCompute", NdarraySizeCompute);

RELAY_REGISTER_UNARY_OP("isnan")
.describe(R"code(Returns whether the input contains any NaN, computed element-wise.
.. math::
isnan(x)
)code" TVM_ADD_FILELINE)
.set_support_level(3)
.add_type_rel("IdentityCompRel", IdentityCompRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::isnan));

RELAY_REGISTER_UNARY_OP("isfinite")
.describe(R"code(Returns the finiteness of input, computed element-wise.
.. math::
@@ -438,7 +447,7 @@ RELAY_REGISTER_UNARY_OP("isfinite")
RELAY_REGISTER_UNARY_OP("isinf")
.describe(R"code(Returns the infiniteness of input, computed element-wise.
.. math::
   isinf(x)
)code" TVM_ADD_FILELINE)
.set_support_level(3)
.add_type_rel("IdentityCompRel", IdentityCompRel)
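RELAY_REGISTER_UNARY_OP creates a globally registered operator, so the registration can be sanity-checked from Python; a small sketch using relay.op.get:

from tvm import relay

op = relay.op.get("isnan")
print(op.name, op.support_level)  # expected: isnan 3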
8 changes: 8 additions & 0 deletions src/target/intrin_rule.cc
@@ -96,6 +96,14 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sigmoid")
  *rv = one / (one + exp(-call->args[0]));
});

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.nan")
.set_body([](const TVMArgs& args, TVMRetValue* rv){
PrimExpr e = args[0];
const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr);
*rv = isnan(call->args[0]);
});

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.isfinite")
.set_body([](const TVMArgs& args, TVMRetValue* rv){
PrimExpr e = args[0];
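The default rules above are among the intrinsic lowerings consulted during codegen. A sketch of exercising isnan end-to-end from the TE level, assuming the tvm.tir.isnan binding exists in this TVM version and mirrors the C++ isnan used here:

import tvm
from tvm import te

x = te.placeholder((4,), name="x", dtype="float32")
y = te.compute(x.shape, lambda i: tvm.tir.isnan(x[i]), name="y")
s = te.create_schedule(y.op)
print(tvm.lower(s, [x, y], simple_mode=True))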
112 changes: 112 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
@@ -1441,6 +1441,110 @@ def forward(self, *args):
    verify_model(Variance5().float().eval(), input_data=input_data)



def test_forward_isfinite():
    torch.set_grad_enabled(False)

    class IsFinite1(Module):
        def forward(self, *args):
            return torch.isfinite(args[0])

    input_data = torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]).float()
    verify_model(IsFinite1().float().eval(), input_data=input_data)


def test_forward_isnan():
    torch.set_grad_enabled(False)

    class IsNan1(Module):
        def forward(self, *args):
            return torch.isnan(args[0])

    input_data = torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]).float()
    verify_model(IsNan1().float().eval(), input_data=input_data)


def test_forward_isinf():
    torch.set_grad_enabled(False)

    class IsInf1(Module):
        def forward(self, *args):
            return torch.isinf(args[0])

    input_data = torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]).float()
    verify_model(IsInf1().float().eval(), input_data=input_data)


def test_forward_rsqrt():
    torch.set_grad_enabled(False)
    input_shape = [1, 3, 10, 10]

    class Rsqrt1(Module):
        def forward(self, *args):
            return torch.rsqrt(args[0])

    input_data = torch.rand(input_shape).float()
    verify_model(Rsqrt1().float().eval(), input_data=input_data)


def test_forward_ceil():
    torch.set_grad_enabled(False)
    input_shape = [1, 3, 10, 10]

    class Ceil1(Module):
        def forward(self, *args):
            return torch.ceil(args[0])

    input_data = torch.rand(input_shape).float()
    verify_model(Ceil1().float().eval(), input_data=input_data)


def test_forward_clamp():
    torch.set_grad_enabled(False)
    input_shape = [1, 3, 10, 10]

    class Clamp1(Module):
        def forward(self, *args):
            return torch.clamp(args[0], min=-0.5, max=0.5)

    class Clamp2(Module):
        def forward(self, *args):
            return torch.clamp(args[0], min=-0.3)

    class Clamp3(Module):
        def forward(self, *args):
            return torch.clamp(args[0], max=1.0)

    input_data = torch.rand(input_shape).float()
    verify_model(Clamp1().float().eval(), input_data=input_data)
    verify_model(Clamp2().float().eval(), input_data=input_data)
    verify_model(Clamp3().float().eval(), input_data=input_data)


def test_forward_floor():
    torch.set_grad_enabled(False)
    input_shape = [1, 3, 10, 10]

    class Floor1(Module):
        def forward(self, *args):
            return torch.floor(args[0])

    input_data = torch.rand(input_shape).float()
    verify_model(Floor1().float().eval(), input_data=input_data)


def test_forward_round():
    torch.set_grad_enabled(False)
    input_shape = [1, 3, 10, 10]

    class Round1(Module):
        def forward(self, *args):
            return torch.round(args[0])

    input_data = torch.rand(input_shape).float()
    verify_model(Round1().float().eval(), input_data=input_data)


if __name__ == "__main__":
    # Single operator tests
    test_forward_add()
@@ -1497,6 +1601,14 @@ def forward(self, *args):
    test_forward_expand()
    test_forward_pow()
    test_forward_abs()
    test_forward_rsqrt()
    test_forward_ceil()
    test_forward_clamp()
    test_forward_floor()
    test_forward_round()
    test_forward_isfinite()
    test_forward_isnan()
    test_forward_isinf()
    test_forward_arange()
    test_forward_chunk()
    test_forward_split()
