Skip to content

Commit

Permalink
Call super().build(input_shape) instead of self.built = True in a…
Browse files Browse the repository at this point in the history
…ll Keras layers.

Within `build()`, some Keras layers were calling `super().build(input_shape)` while others were setting `self.built = True`. This resulted in a different config when serializing, because layers that only set `self.built = True` would not have a `build_config`.

This change makes it consistent between all the layers as well as consistent with Keras 3.

Note that some layers need to call `Layer.build(self, input_shape)` directly to bypass a parent class's `build()` while still populating the information for the `build_config`.

PiperOrigin-RevId: 678454186
  • Loading branch information
hertschuh authored and tensorflower-gardener committed Sep 25, 2024
1 parent d41c2c1 commit ab9494c
Show file tree
Hide file tree
Showing 46 changed files with 95 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ tf_class {
}
member_method {
name: "build"
argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "build_from_config"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ tf_class {
}
member_method {
name: "build"
argspec: "args=[\'self\', \'inputs_shape\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "build_from_config"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ tf_class {
}
member_method {
name: "build"
argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "build_from_config"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ tf_class {
}
member_method {
name: "build"
argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "build_from_config"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ tf_class {
}
member_method {
name: "build"
argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "build_from_config"
Expand Down
4 changes: 2 additions & 2 deletions tf_keras/engine/base_layer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1877,9 +1877,9 @@ class AddLayer(base_layer.Layer):
Useful for testing a layer with a variable
"""

def build(self, _):
def build(self, input_shape):
self.v = self.add_weight("v", (), initializer="ones")
self.built = True
super().build(input_shape)

def call(self, inputs):
return inputs + self.v
Expand Down
2 changes: 1 addition & 1 deletion tf_keras/engine/functional_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def build(self, input_shape):
self.a, [[1.0]], name="unconditional_update"
)
)
self.built = True
super().build(input_shape)

def call(self, inputs):
self.add_update(
Expand Down
4 changes: 2 additions & 2 deletions tf_keras/engine/training_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ def __init__(self, input_shape=None, **kwargs):

def build(self, input_shape):
self._nested_layer = NestedReturnTraining()
self.built = True
super().build(input_shape)

def call(self, inputs):
return self._nested_layer(inputs)
Expand Down Expand Up @@ -4148,7 +4148,7 @@ def build(self, input_shape):
self.a = self.add_weight(
"a", (1, 1), initializer="ones", trainable=False
)
self.built = True
super().build(input_shape)

def call(self, inputs):
self.add_metric(
Expand Down
2 changes: 1 addition & 1 deletion tf_keras/export/export_lib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def __init__(self, **kwargs):

def build(self, input_shape):
self.dense.build(input_shape)
self.built = True
super().build(input_shape)

def call(self, x):
return self.dense(x)
Expand Down
2 changes: 1 addition & 1 deletion tf_keras/layers/activation/prelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def build(self, input_shape):
if i not in self.shared_axes:
axes[i] = input_shape[i]
self.input_spec = InputSpec(ndim=len(input_shape), axes=axes)
self.built = True
super().build(input_shape)

def call(self, inputs):
pos = backend.relu(inputs)
Expand Down
3 changes: 2 additions & 1 deletion tf_keras/layers/attention/base_dense_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ def build(self, input_shape):
# be purely stateless, with no reference to any variable.
if self.dropout > 0:
super().build(input_shape)
self.built = True
else:
base_layer.Layer.build(self, input_shape)

def _calculate_scores(self, query, key):
"""Calculates attention scores.
Expand Down
5 changes: 5 additions & 0 deletions tf_keras/layers/convolutional/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ py_library(
"//tf_keras:constraints",
"//tf_keras:regularizers",
"//tf_keras/dtensor:utils",
"//tf_keras/engine:base_layer",
"//tf_keras/engine:input_spec",
"//tf_keras/initializers",
"//tf_keras/utils:engine_utils",
Expand All @@ -132,6 +133,7 @@ py_library(
"//tf_keras:constraints",
"//tf_keras:regularizers",
"//tf_keras/dtensor:utils",
"//tf_keras/engine:base_layer",
"//tf_keras/engine:input_spec",
"//tf_keras/initializers",
"//tf_keras/utils:engine_utils",
Expand All @@ -149,6 +151,7 @@ py_library(
"//tf_keras:constraints",
"//tf_keras:regularizers",
"//tf_keras/dtensor:utils",
"//tf_keras/engine:base_layer",
"//tf_keras/engine:input_spec",
"//tf_keras/initializers",
"//tf_keras/utils:engine_utils",
Expand All @@ -165,6 +168,7 @@ py_library(
"//tf_keras:activations",
"//tf_keras:constraints",
"//tf_keras:regularizers",
"//tf_keras/engine:base_layer",
"//tf_keras/engine:input_spec",
"//tf_keras/initializers",
],
Expand Down Expand Up @@ -209,6 +213,7 @@ py_library(
"//:expect_tensorflow_installed",
"//tf_keras:constraints",
"//tf_keras:regularizers",
"//tf_keras/engine:base_layer",
"//tf_keras/engine:input_spec",
"//tf_keras/initializers",
],
Expand Down
2 changes: 1 addition & 1 deletion tf_keras/layers/convolutional/base_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def build(self, input_shape):
self.input_spec = InputSpec(
min_ndim=self.rank + 2, axes={channel_axis: input_channel}
)
self.built = True
super().build(input_shape)

def convolution_op(self, inputs, kernel):
if self.padding == "causal":
Expand Down
4 changes: 3 additions & 1 deletion tf_keras/layers/convolutional/base_depthwise_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from tf_keras import constraints
from tf_keras import initializers
from tf_keras import regularizers
from tf_keras.engine.base_layer import Layer
from tf_keras.engine.input_spec import InputSpec
from tf_keras.layers.convolutional.base_conv import Conv

Expand Down Expand Up @@ -202,7 +203,8 @@ def build(self, input_shape):
self.input_spec = InputSpec(
min_ndim=self.rank + 2, axes={channel_axis: input_dim}
)
self.built = True
# Call Layer.build() to skip Conv.build() which we override here.
Layer.build(self, input_shape)

def call(self, inputs):
raise NotImplementedError
Expand Down
4 changes: 3 additions & 1 deletion tf_keras/layers/convolutional/base_separable_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from tf_keras import constraints
from tf_keras import initializers
from tf_keras import regularizers
from tf_keras.engine.base_layer import Layer
from tf_keras.engine.input_spec import InputSpec
from tf_keras.layers.convolutional.base_conv import Conv

Expand Down Expand Up @@ -203,7 +204,8 @@ def build(self, input_shape):
)
else:
self.bias = None
self.built = True
# Call Layer.build() to skip Conv.build() which we override here.
Layer.build(self, input_shape)

def call(self, inputs):
raise NotImplementedError
Expand Down
4 changes: 3 additions & 1 deletion tf_keras/layers/convolutional/conv1d_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from tf_keras import initializers
from tf_keras import regularizers
from tf_keras.dtensor import utils
from tf_keras.engine.base_layer import Layer
from tf_keras.engine.input_spec import InputSpec
from tf_keras.layers.convolutional.conv1d import Conv1D
from tf_keras.utils import conv_utils
Expand Down Expand Up @@ -214,7 +215,8 @@ def build(self, input_shape):
)
else:
self.bias = None
self.built = True
# Call Layer.build() to skip Conv.build() which we override here.
Layer.build(self, input_shape)

def call(self, inputs):
inputs_shape = tf.shape(inputs)
Expand Down
4 changes: 3 additions & 1 deletion tf_keras/layers/convolutional/conv2d_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from tf_keras import initializers
from tf_keras import regularizers
from tf_keras.dtensor import utils
from tf_keras.engine.base_layer import Layer
from tf_keras.engine.input_spec import InputSpec
from tf_keras.layers.convolutional.conv2d import Conv2D
from tf_keras.utils import conv_utils
Expand Down Expand Up @@ -240,7 +241,8 @@ def build(self, input_shape):
)
else:
self.bias = None
self.built = True
# Call Layer.build() to skip Conv.build() which we override here.
Layer.build(self, input_shape)

def call(self, inputs):
inputs_shape = tf.shape(inputs)
Expand Down
4 changes: 3 additions & 1 deletion tf_keras/layers/convolutional/conv3d_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from tf_keras import initializers
from tf_keras import regularizers
from tf_keras.dtensor import utils
from tf_keras.engine.base_layer import Layer
from tf_keras.engine.input_spec import InputSpec
from tf_keras.layers.convolutional.conv3d import Conv3D
from tf_keras.utils import conv_utils
Expand Down Expand Up @@ -247,7 +248,8 @@ def build(self, input_shape):
)
else:
self.bias = None
self.built = True
# Call Layer.build() to skip Conv.build() which we override here.
Layer.build(self, input_shape)

def call(self, inputs):
inputs_shape = tf.shape(inputs)
Expand Down
2 changes: 1 addition & 1 deletion tf_keras/layers/core/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def build(self, input_shape):
)
else:
self.bias = None
self.built = True
super().build(input_shape)

def call(self, inputs):
if inputs.dtype.base_dtype != self._compute_dtype_object.base_dtype:
Expand Down
2 changes: 1 addition & 1 deletion tf_keras/layers/core/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def build(self, input_shape=None):
constraint=self.embeddings_constraint,
experimental_autocast=False,
)
self.built = True
super().build(input_shape)

def compute_mask(self, inputs, mask=None):
if not self.mask_zero:
Expand Down
2 changes: 1 addition & 1 deletion tf_keras/layers/locally_connected/locally_connected1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def build(self, input_shape):
self.input_spec = InputSpec(ndim=3, axes={1: input_dim})
else:
self.input_spec = InputSpec(ndim=3, axes={-1: input_dim})
self.built = True
super().build(input_shape)

@tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape):
Expand Down
2 changes: 1 addition & 1 deletion tf_keras/layers/locally_connected/locally_connected2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def build(self, input_shape):
self.input_spec = InputSpec(ndim=4, axes={1: input_filter})
else:
self.input_spec = InputSpec(ndim=4, axes={-1: input_filter})
self.built = True
super().build(input_shape)

@tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape):
Expand Down
2 changes: 1 addition & 1 deletion tf_keras/layers/normalization/batch_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ def _renorm_variable(name, shape, initializer="zeros"):
finally:
if partitioner:
self._scope.set_partitioner(partitioner)
self.built = True
super().build(input_shape)

def call(self, inputs, training=None, mask=None):
inputs = tf.cast(inputs, self.compute_dtype)
Expand Down
2 changes: 1 addition & 1 deletion tf_keras/layers/normalization/layer_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def build(self, input_shape):
self.beta = None

self._fused = self._fused_can_be_used(rank)
self.built = True
super().build(input_shape)

def call(self, inputs):
# TODO(b/229545225): Remove the RaggedTensor check.
Expand Down
2 changes: 1 addition & 1 deletion tf_keras/layers/rnn/abstract_rnn_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def build(self, input_shape):
shape=(self.units, self.units),
initializer='uniform',
name='recurrent_kernel')
self.built = True
super().build(input_shape)
def call(self, inputs, states):
prev_output = states[0]
Expand Down
1 change: 0 additions & 1 deletion tf_keras/layers/rnn/base_conv_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,6 @@ def bias_initializer(_, *args, **kwargs):
)
else:
self.bias = None
self.built = True

def call(self, inputs, states, training=None):
h_tm1 = states[0] # previous memory state
Expand Down
4 changes: 3 additions & 1 deletion tf_keras/layers/rnn/base_conv_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from tf_keras import backend
from tf_keras.engine import base_layer
from tf_keras.engine.base_layer import Layer
from tf_keras.engine.input_spec import InputSpec
from tf_keras.layers.rnn.base_rnn import RNN
from tf_keras.utils import conv_utils
Expand Down Expand Up @@ -207,6 +208,8 @@ def compute_output_shape(self, input_shape):

@tf_utils.shape_type_conversion
def build(self, input_shape):
# Call Layer.build() to skip RNN.build() which we override here.
Layer.build(self, input_shape)
# Note input_shape will be list of shapes of initial states and
# constants if these are passed in __call__.
if self._num_constants is not None:
Expand Down Expand Up @@ -263,7 +266,6 @@ def build(self, input_shape):
]
if self.stateful:
self.reset_states()
self.built = True

def get_initial_state(self, inputs):
# (samples, timesteps, img_dims..., filters)
Expand Down
2 changes: 1 addition & 1 deletion tf_keras/layers/rnn/base_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def build(self, input_shape):
shape=(self.units, self.units),
initializer='uniform',
name='recurrent_kernel')
self.built = True
super().build(input_shape)
def call(self, inputs, states):
prev_output = states[0]
Expand Down
Loading

0 comments on commit ab9494c

Please sign in to comment.