Skip to content
This repository has been archived by the owner on Oct 13, 2021. It is now read-only.

Commit

Permalink
Update the behavior on custom op. (#459)
Browse files Browse the repository at this point in the history
  • Loading branch information
wenbingl authored Apr 26, 2020
1 parent 996bf2d commit cd9918a
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 9 deletions.
2 changes: 1 addition & 1 deletion keras2onnx/_parser_1x.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from .common import k2o_logger
from .funcbook import get_converter
from ._parse_tf import infer_variable_type, tsname_to_node, adjust_input_batch_size
from ._parser_tf import infer_variable_type, tsname_to_node, adjust_input_batch_size


def extract_inbound_nodes(layer):
Expand Down
File renamed without changes.
5 changes: 5 additions & 0 deletions keras2onnx/_tf_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,16 @@ def convert_tf_random_standard_normal(scope, operator, container):


def pass_thru_converter(scope, operator, container):
"""
This converter is to copy the original graph node with its def into a ONNX node format.
"""
tf_op = operator.raw_operator
attrs = _to_onnx_attrs(tf_op)

container.add_node(operator.type,
operator.input_full_names,
operator.output_full_names,
name=operator.full_name,
op_domain='ai.onnx.contrib',
op_version=1,
**attrs)
3 changes: 2 additions & 1 deletion keras2onnx/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
from .topology import Topology
from .common.utils import set_logger_level, k2o_logger
from .funcbook import set_converter
from ._tf_utils import tsname_to_node
from ._builtin import register_direct_tf_ops
from ._parser_1x import build_opdict_from_keras
from ._parse_tf import tsname_to_node, build_layer_output_from_model
from ._parser_tf import build_layer_output_from_model


def convert_keras(model, name=None, doc_string='', target_opset=None,
Expand Down
14 changes: 7 additions & 7 deletions keras2onnx/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from .funcbook import get_converter, set_converter
from ._consts import TYPES
from ._tf_ops import pass_thru_converter
from ._parse_tf import (infer_variable_type, LayerInfo, is_placeholder_node,
tsname_to_node, on_parsing_keras_layer_v2, adjust_input_batch_size as _adjust_input_batch_size)
from ._parser_tf import (infer_variable_type, LayerInfo, is_placeholder_node,
tsname_to_node, on_parsing_keras_layer_v2, adjust_input_batch_size as _adjust_input_batch_size)
from ._parser_1x import (extract_inbound_nodes,
list_input_tensors, list_input_mask, list_output_mask,
list_output_tensors, list_input_shapes, list_output_shapes, on_parsing_keras_layer)
Expand Down Expand Up @@ -294,15 +294,15 @@ def _check_tfnode_converter_availability(graph, node):


def _check_tfnodes_converter_availability(graph, nodelist, debug_mode):
status = True
for n_ in nodelist:
if not _check_tfnode_converter_availability(graph, n_):
k2o_logger().warning(
"The tf.op node {} of type {} cannot be converted".format(n_.name, n_.type))
if debug_mode:
continue
return False
"WARN: No corresponding ONNX op matches the tf.op node {} of type {}".format(n_.name, n_.type) +
"\n The generated ONNX model needs run with the custom op supports.")
status = False

return True
return status


def _on_parsing_tf_nodes(graph, nodelist, varset, debug_mode):
Expand Down

0 comments on commit cd9918a

Please sign in to comment.