[NNVM] Error handling message (apache#14)
* [NNVM] Error handling message
wweic authored Mar 12, 2019
1 parent 717b019 commit cdbabd6
Showing 1 changed file with 32 additions and 16 deletions.
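The pattern applied throughout the diff: every error raised by the TensorFlow
frontend now carries a category prefix (OperatorNotImplemented,
OperatorAttributeNotImplemented, or OperatorAttributeValueNotValid) and names
the offending operator, so failures can be triaged by grepping logs. A minimal
sketch of the convention; the helper format_frontend_error is hypothetical and
for illustration only (the frontend builds these messages inline):

    ERROR_CATEGORIES = (
        "OperatorNotImplemented",           # op missing from _convert_map
        "OperatorAttributeNotImplemented",  # op exists, attribute unsupported
        "OperatorAttributeValueNotValid",   # attribute exists, value unsupported
    )

    def format_frontend_error(category, detail):
        """Prepend the error category so logs can be grepped and triaged."""
        assert category in ERROR_CATEGORIES
        return "{} : {}".format(category, detail)

    # Example: the message this commit produces for an unsupported kernel rank.
    # raise NotImplementedError(format_frontend_error(
    #     "OperatorAttributeNotImplemented", "Only 2d kernel supported in conv."))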
nnvm/python/nnvm/frontend/tensorflow.py
@@ -68,7 +68,8 @@ def _impl(attr):
         kernel = attr['kernel_shape']
         if len(kernel) == 2:
             return prefix + '2d' + surfix
-        raise NotImplementedError("Only 2d kernel supported.")
+        raise NotImplementedError("OperatorAttributeNotImplemented : "
+                                  "Only 2d kernel supported in {}.".format(prefix))
     return _impl
 
 def _dimension_constraint():
@@ -103,6 +104,7 @@ def _impl(inputs, attr, params):
             axis_input_vlaue = params[axis_input_name].asnumpy()[0]
         except (IndexError, KeyError):
             raise TypeError( \
+                "OperatorAttributeValueNotValid : "
                 "Unsupported argument for `{}` : `axis` should be a constant".format(func_name))
         return func(inputs[0], axis=axis_input_vlaue, keepdims=False)
     return _impl
@@ -129,7 +131,9 @@ def _impl(inputs, attr, params):
             attr['kernel_shape'] = (attr['ksize'][2], attr['ksize'][3])
             attr['strides'] = (attr['strides'][2], attr['strides'][3])
         else:
-            raise TypeError("Unsupported data_format type : {}".format(attr['data_format']))
+            raise TypeError("OperatorAttributeValueNotValid : "
+                            "Unsupported data_format type : {} in {}"\
+                            .format(attr['data_format'], name))
 
         if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
             tmp_shape = attr['_input_shapes'][inputs[0]]
@@ -158,7 +162,9 @@ def _impl(inputs, attr, params):
 
             attr['padding'] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]]
         else:
-            raise TypeError("Unsupported padding type : {}".format(attr['padding']))
+            raise TypeError("OperatorAttributeValueNotValid : "
+                            "Unsupported padding type : {} in {}"\
+                            .format(attr['padding'], name))
 
         if name == "avg_pool":
             attr['count_include_pad'] = False
@@ -232,8 +238,9 @@ def _impl(inputs, attr, params):
             attr['dilations'] = (attr['dilations'][2], attr['dilations'][3])
             attr['strides'] = (attr['strides'][2], attr['strides'][3])
         else:
-            raise TypeError("Unsupported data format type : {}".format(attr['data_format']))
-
+            raise TypeError("OperatorAttributeValueNotValid : "
+                            "Unsupported data format type : {} in {}"\
+                            .format(attr['data_format'], opname))
 
         if opname == 'depthwise':
             attr['groups'] = attr['channels']
@@ -276,7 +283,9 @@ def _impl(inputs, attr, params):
             attr['padding'] = [0, 0]
 
         else:
-            raise TypeError("Unsupported padding type : {}".format(attr['padding']))
+            raise TypeError("OperatorAttributeValueNotValid : "
+                            "Unsupported padding type : {} in {}"\
+                            .format(attr['padding'], opname))
 
         if 'kernel_layout' not in attr:
             if opname == 'conv':
@@ -440,8 +449,8 @@ def _impl(inputs, attr, params):
         lhs = inputs[0]
         rhs = inputs[1]
         if "_input_shapes" in attr:
-            lhs_shape = attr["_input_shapes"][lhs][0]
-            rhs_shape = attr["_input_shapes"][rhs][0]
+            lhs_shape = attr["_input_shapes"][lhs]
+            rhs_shape = attr["_input_shapes"][rhs]
             if len(lhs_shape) == 4 and len(rhs_shape) == 1 and lhs_shape[3] == rhs_shape[0]:
                 # FIXME(yizhiliu): workaround for TF models on TRT. @re:invent 2018
                 # bias_add(NHWC, C), expand_dim C to [1, 1, C],
@@ -747,7 +756,9 @@ def _impl(inputs, attr, params):
         if padlist_key in params:
             padlist = params.pop(padlist_key).asnumpy()
         else:
-            raise RuntimeError("Required parameter {} not fount.".format(padlist_key))
+            raise RuntimeError("OperatorAttributeValueNotValid : "
+                               "Required parameter {} not found in {}."\
+                               .format(padlist_key, name))
         paddings = tuple([tuple(l) for l in padlist])
         attr['pad_width'] = paddings
         attr['pad_value'] = 0
@@ -963,6 +974,8 @@ def _impl(inputs, attr, params):
     'Split' : _split(False),
     'SplitV' : _split(True),
     'Unpack' : _unpack(),
+    'QueueDequeueManyV2' : _undef(),
+    'FIFOQueueV2' : _undef(),
 }
 
 # _convert_map_rnn defines maps of rnn operator name to
@@ -1167,7 +1180,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
         Follow the tensorflow graph definition to parse and convert it to NNVM.
         Some of the assumptions listed below.
 
-            -> All Placeholders are considered as graph input.
+            -> All Placeholders/PlaceholderWithDefaults are considered as graph input.
             -> All Const nodes are params.
             -> Last node is assumed as graph output.
             -> _output_shapes : Graph should be frozen with add_shapes=True.
@@ -1206,10 +1219,11 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
 
         if missing_operators:
             raise NotImplementedError( \
+                "OperatorNotImplemented : "
                 "The following operators are not implemented: {}".format(missing_operators))
 
         for node in graph.node:
-            if node.op == 'Placeholder':
+            if node.op == 'Placeholder' or node.op == 'PlaceholderWithDefault':
                 if shape and node.name in shape:
                     self._input_shapes[node.name] = list(shape[node.name])
                     continue
@@ -1246,7 +1260,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
             elif shape and node.name in shape:
                 # Give priority to user argument.
                 self._output_shapes[node.name] = [shape[node.name]]
-            elif node.op == 'Placeholder':
+            elif node.op == 'Placeholder' or node.op == 'PlaceholderWithDefault':
                 self._output_shapes[node.name] = [self._input_shapes[node.name]]
             elif '_output_shapes' in attr:
                 self._output_shapes[node.name] = \
@@ -1262,7 +1276,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
                 not tshape if isinstance(tshape, list) else False \
                 for tshape in self._output_shapes[node.name]]
 
-            if node.op == "Placeholder":
+            if node.op == "Placeholder" or node.op == 'PlaceholderWithDefault':
                 self._nodes[node.name] = _sym.Variable(name=node.name,
                                                        shape=self._input_shapes[node.name])
 
@@ -1379,7 +1393,7 @@ def _parse_import_prerequisites(self, graph):
         """
         missing_operators = set()
         for node in graph.node:
-            if node.op == "Placeholder":
+            if node.op == "Placeholder" or node.op == "PlaceholderWithDefault":
                 pass
             elif node.op == "Const":
                 pass
@@ -1419,7 +1433,8 @@ def _parse_param(self, key, value, name):
         else:
             if key not in ('dtype', '_output_shapes', '_class'):
                 raise NotImplementedError \
-                    ("Other attributes for a Const(param) Node {} ? .".format(key))
+                    ("OperatorAttributeNotImplemented : "
+                     "Other attributes for a Const(param) Node {} ? .".format(key))
 
     def _get_attr(self, buf):
         """Returns the value of the attr of this buf with the given `name`.
@@ -1546,7 +1561,8 @@ def _convert_operator(self, op_name, inputs, attrs,
                                              self._params, graph,
                                              convert_map_rnn)
         else:
-            raise NotImplementedError("Operator {} not implemented.".format(op_name))
+            raise NotImplementedError("OperatorNotImplemented : "
+                                      "Operator {} not implemented.".format(op_name))
         return sym
 
     def _fix_extranodes(self, op_name, attr, inputs):
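For context, a minimal sketch of how the categorized messages surface to a
caller of the importer. Assumes TensorFlow 1.x and a frozen GraphDef; the
model path is hypothetical:

    import nnvm.frontend
    import tensorflow as tf

    # Load a frozen graph; freeze with add_shapes=True so _output_shapes is
    # populated, per the docstring above.
    graph_def = tf.GraphDef()
    with open("frozen_model.pb", "rb") as f:
        graph_def.ParseFromString(f.read())

    try:
        sym, params = nnvm.frontend.from_tensorflow(graph_def, layout="NHWC")
    except NotImplementedError as err:
        # e.g. "OperatorNotImplemented : The following operators are not
        # implemented: {'SomeOp'}"
        print("TF import failed:", err)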
