Skip to content
This repository has been archived by the owner on Oct 13, 2021. It is now read-only.

Update the behavior on custom op. #459

Merged
merged 2 commits into from
Apr 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion keras2onnx/_parser_1x.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from .common import k2o_logger
from .funcbook import get_converter
from ._parse_tf import infer_variable_type, tsname_to_node, adjust_input_batch_size
from ._parser_tf import infer_variable_type, tsname_to_node, adjust_input_batch_size


def extract_inbound_nodes(layer):
Expand Down
File renamed without changes.
5 changes: 5 additions & 0 deletions keras2onnx/_tf_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,16 @@ def convert_tf_random_standard_normal(scope, operator, container):


def pass_thru_converter(scope, operator, container):
"""
This converter is to copy the original graph node with its def into a ONNX node format.
"""
tf_op = operator.raw_operator
attrs = _to_onnx_attrs(tf_op)

container.add_node(operator.type,
operator.input_full_names,
operator.output_full_names,
name=operator.full_name,
op_domain='ai.onnx.contrib',
op_version=1,
**attrs)
3 changes: 2 additions & 1 deletion keras2onnx/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
from .topology import Topology
from .common.utils import set_logger_level, k2o_logger
from .funcbook import set_converter
from ._tf_utils import tsname_to_node
from ._builtin import register_direct_tf_ops
from ._parser_1x import build_opdict_from_keras
from ._parse_tf import tsname_to_node, build_layer_output_from_model
from ._parser_tf import build_layer_output_from_model


def convert_keras(model, name=None, doc_string='', target_opset=None,
Expand Down
14 changes: 7 additions & 7 deletions keras2onnx/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from .funcbook import get_converter, set_converter
from ._consts import TYPES
from ._tf_ops import pass_thru_converter
from ._parse_tf import (infer_variable_type, LayerInfo, is_placeholder_node,
tsname_to_node, on_parsing_keras_layer_v2, adjust_input_batch_size as _adjust_input_batch_size)
from ._parser_tf import (infer_variable_type, LayerInfo, is_placeholder_node,
tsname_to_node, on_parsing_keras_layer_v2, adjust_input_batch_size as _adjust_input_batch_size)
from ._parser_1x import (extract_inbound_nodes,
list_input_tensors, list_input_mask, list_output_mask,
list_output_tensors, list_input_shapes, list_output_shapes, on_parsing_keras_layer)
Expand Down Expand Up @@ -294,15 +294,15 @@ def _check_tfnode_converter_availability(graph, node):


def _check_tfnodes_converter_availability(graph, nodelist, debug_mode):
status = True
for n_ in nodelist:
if not _check_tfnode_converter_availability(graph, n_):
k2o_logger().warning(
"The tf.op node {} of type {} cannot be converted".format(n_.name, n_.type))
if debug_mode:
continue
return False
"WARN: No corresponding ONNX op matches the tf.op node {} of type {}".format(n_.name, n_.type) +
"\n The generated ONNX model needs run with the custom op supports.")
status = False

return True
return status


def _on_parsing_tf_nodes(graph, nodelist, varset, debug_mode):
Expand Down