This repository has been archived by the owner on Oct 13, 2021. It is now read-only.

Support tf.nn.relu6 #506

Merged: 3 commits, merged on May 31, 2020
10 changes: 10 additions & 0 deletions keras2onnx/_builtin.py
@@ -1723,6 +1723,16 @@ def convert_tf_read_variable_op(scope, operator, container):
                               name=operator.full_name)


+@converter_func(TYPES.Relu6)
+def convert_tf_relu6(scope, operator, container):
+    oopb = OnnxOperatorBuilder(container, scope)
+    oopb.apply_op_with_output("apply_relu6",
+                              operator.input_full_names,
+                              operator.output_full_names,
+                              name=operator.full_name + '_clip',
+                              dtype=operator.raw_operator.input.dtype.as_numpy_dtype)
+
+
 @converter_func(TYPES.Slice)
 def convert_tf_slice(scope, operator, container):
     oopb = OnnxOperatorBuilder(container, scope)
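Note: apply_relu6 is a helper added to keras2onnx's common onnx_ops module in this PR but not shown in this excerpt. Relu6 is conventionally lowered to a single ONNX Clip with min=0 and max=6, which is why the node name above ends in '_clip'. A minimal sketch of what such a helper could look like, assuming the apply_clip helper from onnxconverter-common; the real implementation may differ, for example by using the dtype argument to build min/max tensors of the matching element type for opset 11 and later:

import numpy as np
from onnxconverter_common.onnx_ops import apply_clip


def apply_relu6(scope, input_name, output_name, container, operator_name=None, dtype=np.float32):
    # relu6(x) = min(max(x, 0), 6), i.e. a single Clip(x, min=0, max=6).
    # dtype is accepted only to mirror the call site above; apply_clip is
    # assumed to handle the opset differences (attributes before opset 11,
    # extra inputs afterwards).
    apply_clip(scope, input_name, output_name, container,
               operator_name=operator_name, min=0, max=6)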
1 change: 1 addition & 0 deletions keras2onnx/_consts.py
@@ -62,6 +62,7 @@ class TYPES:
     Range = 'Range'
     ReadVariableOp = 'ReadVariableOp'
     RealDiv = 'RealDiv'
+    Relu6 = 'Relu6'
     Reshape = 'Reshape'
     ResizeBilinear = 'ResizeBilinear'
     ResizeNearestNeighbor = 'ResizeNearestNeighbor'
10 changes: 5 additions & 5 deletions keras2onnx/ke2onnx/activation.py
@@ -5,8 +5,8 @@
 ###############################################################################
 import tensorflow as tf
 from ..proto import keras, is_tf_keras
-from ..common.onnx_ops import apply_elu, apply_hard_sigmoid, apply_relu, apply_sigmoid, apply_tanh, \
-    apply_softmax, apply_identity, apply_selu, apply_clip, apply_mul
+from ..common.onnx_ops import apply_elu, apply_hard_sigmoid, apply_relu, apply_relu6, apply_sigmoid, apply_tanh, \
+    apply_softmax, apply_identity, apply_selu, apply_mul

 activation_get = keras.activations.get

@@ -30,6 +30,7 @@
                   tf.nn.sigmoid: apply_sigmoid,
                   tf.nn.softmax: apply_softmax,
                   tf.nn.relu: apply_relu,
+                  tf.nn.relu6: apply_relu6,
                   tf.nn.elu: apply_elu,
                   tf.nn.tanh: apply_tanh}

@@ -56,9 +57,8 @@ def convert_keras_activation(scope, operator, container):
         apply_selu(scope, input_name, output_name, container, alpha=1.673263, gamma=1.050701)
     elif activation in [relu6] or activation.__name__ == 'relu6':
         # relu6(x) = min(relu(x), 6)
-        apply_relu(scope, input_name, output_name + "_relu6", container)
-        apply_clip(scope, output_name + "_relu6", output_name, container,
-                   min=0, max=6)
+        apply_relu6(scope, input_name, output_name, container,
+                    dtype=operator.raw_operator.input.dtype.as_numpy_dtype)
     elif activation.__name__ in ['swish']:
         apply_sigmoid(scope, input_name, output_name + '_sig', container)
         apply_mul(scope, [input_name, output_name + '_sig'], output_name, container)
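The inline comment above keeps the identity relu6(x) = min(relu(x), 6); because relu already pins the lower bound at zero, the old two-node lowering (Relu followed by Clip) and the new single apply_relu6 call are numerically equivalent. A quick standalone check of that identity, assuming TF 2.x eager execution; this is illustrative only and not part of the PR:

import numpy as np
import tensorflow as tf

x = np.array([-3.0, 0.0, 2.5, 6.0, 9.0], dtype=np.float32)
relu_then_clip = np.clip(np.maximum(x, 0.0), 0.0, 6.0)  # old lowering: Relu -> Clip(0, 6)
single_clip = np.clip(x, 0.0, 6.0)                      # new lowering: one Clip(0, 6)
assert np.allclose(relu_then_clip, single_clip)
assert np.allclose(tf.nn.relu6(x).numpy(), single_clip)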
18 changes: 18 additions & 0 deletions keras2onnx/ke2onnx/common.py
@@ -4,8 +4,26 @@
 # license information.
 ###############################################################################

+from ..proto import keras
+from ..proto.tfcompat import tensorflow as tf
+from ..common.onnx_ops import apply_relu6, apply_softmax
+from .activation import activation_map
+activation_get = keras.activations.get
+

 def get_permutation_config(n_dims):
     input_perm_axes = [0, n_dims + 1] + list(range(1, n_dims + 1))
     output_perm_axes = [0] + list(range(2, n_dims + 2)) + [1]
     return input_perm_axes, output_perm_axes
+
+
+def activation_process(scope, operator, container, biased_tensor_name):
+    # Create an activation function node and apply activation function to the intermediate tensor
+    apply_activation_function = activation_map[operator.raw_operator.activation]
+    if operator.raw_operator.activation in [activation_get('softmax'), keras.activations.softmax]:
+        apply_softmax(scope, biased_tensor_name, operator.outputs[0].full_name, container, axis=-1)
+    elif operator.raw_operator.activation in [tf.nn.relu6]:
+        apply_relu6(scope, biased_tensor_name, operator.outputs[0].full_name, container,
+                    dtype=operator.raw_operator.input.dtype.as_numpy_dtype)
+    else:
+        apply_activation_function(scope, biased_tensor_name, operator.outputs[0].full_name, container)
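activation_process centralizes the activation dispatch that was previously duplicated in the Dense and Conv converters (see the diffs below): softmax needs an explicit axis and tf.nn.relu6 needs the input dtype, so both are special-cased before falling back to the generic apply_* helper looked up in activation_map. A hypothetical layer converter would call it on its pre-activation tensor; the sketch below is illustrative only, with convert_my_layer and the '_biased' tensor name made up for the example, and scope.get_unique_variable_name assumed available on the scope object:

from keras2onnx.ke2onnx.common import activation_process


def convert_my_layer(scope, operator, container):
    # Emit the layer's core ops into an intermediate tensor first
    # (the name below is illustrative).
    biased_tensor_name = scope.get_unique_variable_name(operator.full_name + '_biased')
    # ... nodes producing biased_tensor_name go here ...
    # Then let the shared helper attach the layer's activation.
    activation_process(scope, operator, container, biased_tensor_name)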
10 changes: 3 additions & 7 deletions keras2onnx/ke2onnx/conv.py
@@ -4,11 +4,11 @@
 # license information.
 ###############################################################################
 import numpy
-from .activation import activation_map
+from .common import activation_process
 from ..proto import keras
 from ..proto import onnx_proto
 from ..common.utils import count_dynamic_dim
-from ..common.onnx_ops import (apply_identity, apply_pad, apply_softmax,
+from ..common.onnx_ops import (apply_identity, apply_pad,
                                apply_transpose, apply_mul, apply_sigmoid)

 activation_get = keras.activations.get
@@ -202,11 +202,7 @@ def convert_keras_conv_core(scope, operator, container, is_transpose, n_dims, in
         apply_mul(scope, [transpose_output_name, transpose_output_name + '_sig'], operator.outputs[0].full_name,
                   container)
     else:
-        apply_activation_function = activation_map[op.activation]
-        if op.activation in [activation_get('softmax'), keras.activations.softmax]:
-            apply_softmax(scope, transpose_output_name, operator.outputs[0].full_name, container, axis=-1)
-        else:
-            apply_activation_function(scope, transpose_output_name, operator.outputs[0].full_name, container)
+        activation_process(scope, operator, container, transpose_output_name)


 def get_converter_config(dims, is_conv_transpose):
11 changes: 3 additions & 8 deletions keras2onnx/ke2onnx/dense.py
@@ -4,9 +4,9 @@
 # license information.
 ###############################################################################
 import numpy as np
+from .common import activation_process
 from ..proto import onnx_proto, keras
-from ..common.onnx_ops import apply_softmax, apply_add, OnnxOperatorBuilder
-from .activation import activation_map
+from ..common.onnx_ops import apply_add, OnnxOperatorBuilder
 activation_get = keras.activations.get


@@ -40,9 +40,4 @@ def convert_keras_dense(scope, operator, container):
         apply_add(scope, transformed_tensor_name + [bias_name], biased_tensor_name, container,
                   axis=-1, broadcast=1)

-    # Create an activation function node and apply activation function to the intermediate tensor
-    apply_activation_function = activation_map[operator.raw_operator.activation]
-    if operator.raw_operator.activation in [activation_get('softmax'), keras.activations.softmax]:
-        apply_softmax(scope, biased_tensor_name, operator.outputs[0].full_name, container, axis=-1)
-    else:
-        apply_activation_function(scope, biased_tensor_name, operator.outputs[0].full_name, container)
+    activation_process(scope, operator, container, biased_tensor_name)
2 changes: 1 addition & 1 deletion tests/test_layers.py
@@ -1374,7 +1374,7 @@ def test_Softmax(advanced_activation_runner):


 def test_tf_nn_activation(runner):
-    for activation in [tf.nn.relu, 'relu']:
+    for activation in [tf.nn.relu, 'relu', tf.nn.relu6]:
         model = keras.Sequential([
             Dense(64, activation=activation, input_shape=[10]),
             Dense(64, activation=activation),
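For reference, a rough end-to-end check along the lines of what the runner fixture presumably performs (its exact behavior is not shown in this diff): build a small relu6 model, convert it with keras2onnx, and compare ONNX Runtime output against Keras. Assumes TF 2.x, keras2onnx, and onnxruntime are installed; the model shape and tolerances below are illustrative:

import numpy as np
import onnxruntime as rt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Dense
import keras2onnx

model = keras.Sequential([Dense(64, activation=tf.nn.relu6, input_shape=[10]),
                          Dense(1)])
data = np.random.rand(2, 10).astype(np.float32)
onnx_model = keras2onnx.convert_keras(model, model.name)
sess = rt.InferenceSession(onnx_model.SerializeToString())
onnx_output = sess.run(None, {sess.get_inputs()[0].name: data})[0]
assert np.allclose(model.predict(data), onnx_output, rtol=1e-3, atol=1e-5)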