From ab9494c41c73cb290837d889439069a24fd31400 Mon Sep 17 00:00:00 2001
From: Fabien Hertschuh
Date: Tue, 24 Sep 2024 17:21:57 -0700
Subject: [PATCH] Call `super().build(input_shape)` instead of `self.built = True` in all Keras layers.

Within `build()`, some Keras layers were calling `super().build(input_shape)`
while others were calling `self.built = True`. This resulted in inconsistent
serialized configs: layers that set `self.built = True` would not have a
`build_config`. This change makes all layers consistent with one another and
with Keras 3.

Note that some layers need to call `Layer.build(self, input_shape)` directly
to bypass a parent class's `build()` while still populating the information
for the `build_config`.
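A minimal sketch of the difference (illustrative only, not part of this
patch; `OldStyle` and `NewStyle` are hypothetical layers, and the standalone
`tf_keras` import name is assumed):

    import tf_keras as keras

    class OldStyle(keras.layers.Layer):
        def build(self, input_shape):
            self.v = self.add_weight("v", (), initializer="ones")
            self.built = True  # marks the layer built, records no shape

    class NewStyle(keras.layers.Layer):
        def build(self, input_shape):
            self.v = self.add_weight("v", (), initializer="ones")
            super().build(input_shape)  # also records the build shape

    old, new = OldStyle(), NewStyle()
    old.build((None, 4))
    new.build((None, 4))
    print(old.get_build_config())  # None: no `build_config` is serialized
    print(new.get_build_config())  # e.g. {'input_shape': (None, 4)}

The recorded shape is what `build_from_config()` uses to rebuild a
deserialized layer before it is ever called on data.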
"args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "build_from_config" diff --git a/tf_keras/api/golden/v1/tensorflow.keras.__internal__.legacy.rnn_cell.-dropout-wrapper.pbtxt b/tf_keras/api/golden/v1/tensorflow.keras.__internal__.legacy.rnn_cell.-dropout-wrapper.pbtxt index 8bdb21b1e..fc10ab1ac 100644 --- a/tf_keras/api/golden/v1/tensorflow.keras.__internal__.legacy.rnn_cell.-dropout-wrapper.pbtxt +++ b/tf_keras/api/golden/v1/tensorflow.keras.__internal__.legacy.rnn_cell.-dropout-wrapper.pbtxt @@ -181,7 +181,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'inputs_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "build_from_config" diff --git a/tf_keras/api/golden/v1/tensorflow.keras.__internal__.legacy.rnn_cell.-multi-r-n-n-cell.pbtxt b/tf_keras/api/golden/v1/tensorflow.keras.__internal__.legacy.rnn_cell.-multi-r-n-n-cell.pbtxt index ddfafa107..7b42d9678 100644 --- a/tf_keras/api/golden/v1/tensorflow.keras.__internal__.legacy.rnn_cell.-multi-r-n-n-cell.pbtxt +++ b/tf_keras/api/golden/v1/tensorflow.keras.__internal__.legacy.rnn_cell.-multi-r-n-n-cell.pbtxt @@ -176,7 +176,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "build_from_config" diff --git a/tf_keras/api/golden/v1/tensorflow.keras.__internal__.legacy.rnn_cell.-r-n-n-cell.pbtxt b/tf_keras/api/golden/v1/tensorflow.keras.__internal__.legacy.rnn_cell.-r-n-n-cell.pbtxt index cd510f628..31d12d256 100644 --- a/tf_keras/api/golden/v1/tensorflow.keras.__internal__.legacy.rnn_cell.-r-n-n-cell.pbtxt +++ b/tf_keras/api/golden/v1/tensorflow.keras.__internal__.legacy.rnn_cell.-r-n-n-cell.pbtxt @@ -175,7 +175,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "build_from_config" diff --git a/tf_keras/api/golden/v1/tensorflow.keras.__internal__.legacy.rnn_cell.-residual-wrapper.pbtxt b/tf_keras/api/golden/v1/tensorflow.keras.__internal__.legacy.rnn_cell.-residual-wrapper.pbtxt index 2a923c531..1ad9161ce 100644 --- a/tf_keras/api/golden/v1/tensorflow.keras.__internal__.legacy.rnn_cell.-residual-wrapper.pbtxt +++ b/tf_keras/api/golden/v1/tensorflow.keras.__internal__.legacy.rnn_cell.-residual-wrapper.pbtxt @@ -177,7 +177,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "build_from_config" diff --git a/tf_keras/engine/base_layer_test.py b/tf_keras/engine/base_layer_test.py index e74b7b13c..19d2700b5 100644 --- a/tf_keras/engine/base_layer_test.py +++ b/tf_keras/engine/base_layer_test.py @@ -1877,9 +1877,9 @@ class AddLayer(base_layer.Layer): Useful for testing a layer with a variable """ - def build(self, _): + def build(self, input_shape): self.v = self.add_weight("v", (), initializer="ones") - self.built = True + super().build(input_shape) def call(self, inputs): return inputs + self.v diff --git a/tf_keras/engine/functional_test.py b/tf_keras/engine/functional_test.py index 
a8f0a004d..786cc4297 100644 --- a/tf_keras/engine/functional_test.py +++ b/tf_keras/engine/functional_test.py @@ -69,7 +69,7 @@ def build(self, input_shape): self.a, [[1.0]], name="unconditional_update" ) ) - self.built = True + super().build(input_shape) def call(self, inputs): self.add_update( diff --git a/tf_keras/engine/training_test.py b/tf_keras/engine/training_test.py index 27387c5b3..a480666bf 100644 --- a/tf_keras/engine/training_test.py +++ b/tf_keras/engine/training_test.py @@ -527,7 +527,7 @@ def __init__(self, input_shape=None, **kwargs): def build(self, input_shape): self._nested_layer = NestedReturnTraining() - self.built = True + super().build(input_shape) def call(self, inputs): return self._nested_layer(inputs) @@ -4148,7 +4148,7 @@ def build(self, input_shape): self.a = self.add_weight( "a", (1, 1), initializer="ones", trainable=False ) - self.built = True + super().build(input_shape) def call(self, inputs): self.add_metric( diff --git a/tf_keras/export/export_lib_test.py b/tf_keras/export/export_lib_test.py index 62d11a2e2..79dfc4aa3 100644 --- a/tf_keras/export/export_lib_test.py +++ b/tf_keras/export/export_lib_test.py @@ -396,7 +396,7 @@ def __init__(self, **kwargs): def build(self, input_shape): self.dense.build(input_shape) - self.built = True + super().build(input_shape) def call(self, x): return self.dense(x) diff --git a/tf_keras/layers/activation/prelu.py b/tf_keras/layers/activation/prelu.py index d3d6384ef..0dddf5e9f 100644 --- a/tf_keras/layers/activation/prelu.py +++ b/tf_keras/layers/activation/prelu.py @@ -102,7 +102,7 @@ def build(self, input_shape): if i not in self.shared_axes: axes[i] = input_shape[i] self.input_spec = InputSpec(ndim=len(input_shape), axes=axes) - self.built = True + super().build(input_shape) def call(self, inputs): pos = backend.relu(inputs) diff --git a/tf_keras/layers/attention/base_dense_attention.py b/tf_keras/layers/attention/base_dense_attention.py index 7a22a1167..3724b2c2f 100644 --- a/tf_keras/layers/attention/base_dense_attention.py +++ b/tf_keras/layers/attention/base_dense_attention.py @@ -86,7 +86,8 @@ def build(self, input_shape): # be purely stateless, with no reference to any variable. if self.dropout > 0: super().build(input_shape) - self.built = True + else: + base_layer.Layer.build(self, input_shape) def _calculate_scores(self, query, key): """Calculates attention scores. 
diff --git a/tf_keras/layers/convolutional/BUILD b/tf_keras/layers/convolutional/BUILD
index c964a0d45..5c060fc89 100644
--- a/tf_keras/layers/convolutional/BUILD
+++ b/tf_keras/layers/convolutional/BUILD
@@ -114,6 +114,7 @@ py_library(
         "//tf_keras:constraints",
         "//tf_keras:regularizers",
         "//tf_keras/dtensor:utils",
+        "//tf_keras/engine:base_layer",
         "//tf_keras/engine:input_spec",
         "//tf_keras/initializers",
         "//tf_keras/utils:engine_utils",
@@ -132,6 +133,7 @@ py_library(
         "//tf_keras:constraints",
         "//tf_keras:regularizers",
         "//tf_keras/dtensor:utils",
+        "//tf_keras/engine:base_layer",
         "//tf_keras/engine:input_spec",
         "//tf_keras/initializers",
         "//tf_keras/utils:engine_utils",
@@ -149,6 +151,7 @@ py_library(
         "//tf_keras:constraints",
         "//tf_keras:regularizers",
         "//tf_keras/dtensor:utils",
+        "//tf_keras/engine:base_layer",
         "//tf_keras/engine:input_spec",
         "//tf_keras/initializers",
         "//tf_keras/utils:engine_utils",
@@ -165,6 +168,7 @@ py_library(
         "//tf_keras:activations",
         "//tf_keras:constraints",
         "//tf_keras:regularizers",
+        "//tf_keras/engine:base_layer",
         "//tf_keras/engine:input_spec",
         "//tf_keras/initializers",
     ],
@@ -209,6 +213,7 @@ py_library(
         "//:expect_tensorflow_installed",
         "//tf_keras:constraints",
         "//tf_keras:regularizers",
+        "//tf_keras/engine:base_layer",
         "//tf_keras/engine:input_spec",
         "//tf_keras/initializers",
     ],
diff --git a/tf_keras/layers/convolutional/base_conv.py b/tf_keras/layers/convolutional/base_conv.py
index b88695235..7c098ea6a 100644
--- a/tf_keras/layers/convolutional/base_conv.py
+++ b/tf_keras/layers/convolutional/base_conv.py
@@ -248,7 +248,7 @@ def build(self, input_shape):
         self.input_spec = InputSpec(
             min_ndim=self.rank + 2, axes={channel_axis: input_channel}
         )
-        self.built = True
+        super().build(input_shape)
 
     def convolution_op(self, inputs, kernel):
         if self.padding == "causal":
diff --git a/tf_keras/layers/convolutional/base_depthwise_conv.py b/tf_keras/layers/convolutional/base_depthwise_conv.py
index 497289e95..d6edc9a11 100644
--- a/tf_keras/layers/convolutional/base_depthwise_conv.py
+++ b/tf_keras/layers/convolutional/base_depthwise_conv.py
@@ -20,6 +20,7 @@
 from tf_keras import constraints
 from tf_keras import initializers
 from tf_keras import regularizers
+from tf_keras.engine.base_layer import Layer
 from tf_keras.engine.input_spec import InputSpec
 from tf_keras.layers.convolutional.base_conv import Conv
 
@@ -202,7 +203,8 @@ def build(self, input_shape):
         self.input_spec = InputSpec(
             min_ndim=self.rank + 2, axes={channel_axis: input_dim}
         )
-        self.built = True
+        # Call Layer.build() to skip Conv.build() which we override here.
+        Layer.build(self, input_shape)
 
     def call(self, inputs):
         raise NotImplementedError
diff --git a/tf_keras/layers/convolutional/base_separable_conv.py b/tf_keras/layers/convolutional/base_separable_conv.py
index d1a81b2cc..baa0242b1 100644
--- a/tf_keras/layers/convolutional/base_separable_conv.py
+++ b/tf_keras/layers/convolutional/base_separable_conv.py
@@ -21,6 +21,7 @@
 from tf_keras import constraints
 from tf_keras import initializers
 from tf_keras import regularizers
+from tf_keras.engine.base_layer import Layer
 from tf_keras.engine.input_spec import InputSpec
 from tf_keras.layers.convolutional.base_conv import Conv
 
@@ -203,7 +204,8 @@ def build(self, input_shape):
             )
         else:
             self.bias = None
-        self.built = True
+        # Call Layer.build() to skip Conv.build() which we override here.
+        Layer.build(self, input_shape)
 
     def call(self, inputs):
         raise NotImplementedError
diff --git a/tf_keras/layers/convolutional/conv1d_transpose.py b/tf_keras/layers/convolutional/conv1d_transpose.py
index 85d839aa4..177b190d5 100644
--- a/tf_keras/layers/convolutional/conv1d_transpose.py
+++ b/tf_keras/layers/convolutional/conv1d_transpose.py
@@ -22,6 +22,7 @@
 from tf_keras import initializers
 from tf_keras import regularizers
 from tf_keras.dtensor import utils
+from tf_keras.engine.base_layer import Layer
 from tf_keras.engine.input_spec import InputSpec
 from tf_keras.layers.convolutional.conv1d import Conv1D
 from tf_keras.utils import conv_utils
@@ -214,7 +215,8 @@ def build(self, input_shape):
             )
         else:
             self.bias = None
-        self.built = True
+        # Call Layer.build() to skip Conv.build() which we override here.
+        Layer.build(self, input_shape)
 
     def call(self, inputs):
         inputs_shape = tf.shape(inputs)
diff --git a/tf_keras/layers/convolutional/conv2d_transpose.py b/tf_keras/layers/convolutional/conv2d_transpose.py
index f508ffbce..2b2d943ef 100644
--- a/tf_keras/layers/convolutional/conv2d_transpose.py
+++ b/tf_keras/layers/convolutional/conv2d_transpose.py
@@ -23,6 +23,7 @@
 from tf_keras import initializers
 from tf_keras import regularizers
 from tf_keras.dtensor import utils
+from tf_keras.engine.base_layer import Layer
 from tf_keras.engine.input_spec import InputSpec
 from tf_keras.layers.convolutional.conv2d import Conv2D
 from tf_keras.utils import conv_utils
@@ -240,7 +241,8 @@ def build(self, input_shape):
             )
         else:
             self.bias = None
-        self.built = True
+        # Call Layer.build() to skip Conv.build() which we override here.
+        Layer.build(self, input_shape)
 
     def call(self, inputs):
         inputs_shape = tf.shape(inputs)
diff --git a/tf_keras/layers/convolutional/conv3d_transpose.py b/tf_keras/layers/convolutional/conv3d_transpose.py
index c166e4900..8fce2cc5e 100644
--- a/tf_keras/layers/convolutional/conv3d_transpose.py
+++ b/tf_keras/layers/convolutional/conv3d_transpose.py
@@ -22,6 +22,7 @@
 from tf_keras import initializers
 from tf_keras import regularizers
 from tf_keras.dtensor import utils
+from tf_keras.engine.base_layer import Layer
 from tf_keras.engine.input_spec import InputSpec
 from tf_keras.layers.convolutional.conv3d import Conv3D
 from tf_keras.utils import conv_utils
@@ -247,7 +248,8 @@ def build(self, input_shape):
             )
         else:
             self.bias = None
-        self.built = True
+        # Call Layer.build() to skip Conv.build() which we override here.
+        Layer.build(self, input_shape)
 
     def call(self, inputs):
         inputs_shape = tf.shape(inputs)
diff --git a/tf_keras/layers/core/dense.py b/tf_keras/layers/core/dense.py
index e3383cdfd..06e6e5d75 100644
--- a/tf_keras/layers/core/dense.py
+++ b/tf_keras/layers/core/dense.py
@@ -174,7 +174,7 @@ def build(self, input_shape):
             )
         else:
             self.bias = None
-        self.built = True
+        super().build(input_shape)
 
     def call(self, inputs):
         if inputs.dtype.base_dtype != self._compute_dtype_object.base_dtype:
diff --git a/tf_keras/layers/core/embedding.py b/tf_keras/layers/core/embedding.py
index 8cef7d332..7f9b59b06 100644
--- a/tf_keras/layers/core/embedding.py
+++ b/tf_keras/layers/core/embedding.py
@@ -185,7 +185,7 @@ def build(self, input_shape=None):
             constraint=self.embeddings_constraint,
             experimental_autocast=False,
         )
-        self.built = True
+        super().build(input_shape)
 
     def compute_mask(self, inputs, mask=None):
         if not self.mask_zero:
diff --git a/tf_keras/layers/locally_connected/locally_connected1d.py b/tf_keras/layers/locally_connected/locally_connected1d.py
index ad5d202b1..d33268d63 100644
--- a/tf_keras/layers/locally_connected/locally_connected1d.py
+++ b/tf_keras/layers/locally_connected/locally_connected1d.py
@@ -284,7 +284,7 @@ def build(self, input_shape):
             self.input_spec = InputSpec(ndim=3, axes={1: input_dim})
         else:
             self.input_spec = InputSpec(ndim=3, axes={-1: input_dim})
-        self.built = True
+        super().build(input_shape)
 
     @tf_utils.shape_type_conversion
     def compute_output_shape(self, input_shape):
diff --git a/tf_keras/layers/locally_connected/locally_connected2d.py b/tf_keras/layers/locally_connected/locally_connected2d.py
index 8d3d0afc5..db9c59602 100644
--- a/tf_keras/layers/locally_connected/locally_connected2d.py
+++ b/tf_keras/layers/locally_connected/locally_connected2d.py
@@ -308,7 +308,7 @@ def build(self, input_shape):
             self.input_spec = InputSpec(ndim=4, axes={1: input_filter})
         else:
             self.input_spec = InputSpec(ndim=4, axes={-1: input_filter})
-        self.built = True
+        super().build(input_shape)
 
     @tf_utils.shape_type_conversion
     def compute_output_shape(self, input_shape):
diff --git a/tf_keras/layers/normalization/batch_normalization.py b/tf_keras/layers/normalization/batch_normalization.py
index b891846bb..e73b7b3ec 100644
--- a/tf_keras/layers/normalization/batch_normalization.py
+++ b/tf_keras/layers/normalization/batch_normalization.py
@@ -542,7 +542,7 @@ def _renorm_variable(name, shape, initializer="zeros"):
         finally:
             if partitioner:
                 self._scope.set_partitioner(partitioner)
-        self.built = True
+        super().build(input_shape)
 
     def call(self, inputs, training=None, mask=None):
         inputs = tf.cast(inputs, self.compute_dtype)
diff --git a/tf_keras/layers/normalization/layer_normalization.py b/tf_keras/layers/normalization/layer_normalization.py
index e76919224..a34787b19 100644
--- a/tf_keras/layers/normalization/layer_normalization.py
+++ b/tf_keras/layers/normalization/layer_normalization.py
@@ -249,7 +249,7 @@ def build(self, input_shape):
             self.beta = None
 
         self._fused = self._fused_can_be_used(rank)
-        self.built = True
+        super().build(input_shape)
 
     def call(self, inputs):
         # TODO(b/229545225): Remove the RaggedTensor check.
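The bypass pattern used by the depthwise, separable, and transpose conv
layers above, sketched on a hypothetical `Dense` subclass (illustrative
only, not part of this patch):

    from tf_keras.engine.base_layer import Layer
    from tf_keras.layers import Dense

    class ScaledKernelDense(Dense):
        """Hypothetical subclass that creates its own weights in build()."""

        def build(self, input_shape):
            last_dim = int(input_shape[-1])
            self.scale = self.add_weight("scale", shape=(), initializer="ones")
            self.kernel = self.add_weight("kernel", shape=(last_dim, self.units))
            self.bias = None
            # `super().build(input_shape)` would re-run Dense.build() and
            # create `kernel` a second time; `self.built = True` would skip
            # it but record no shape. `Layer.build(self, input_shape)` marks
            # the layer built and records the shape that goes into the
            # `build_config`.
            Layer.build(self, input_shape)

    layer = ScaledKernelDense(8, use_bias=False)  # `bias` stays unused
    layer.build((None, 4))
    print(layer.get_build_config())  # e.g. {'input_shape': (None, 4)}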
diff --git a/tf_keras/layers/rnn/abstract_rnn_cell.py b/tf_keras/layers/rnn/abstract_rnn_cell.py
index 0c7cac053..98d94d3e8 100644
--- a/tf_keras/layers/rnn/abstract_rnn_cell.py
+++ b/tf_keras/layers/rnn/abstract_rnn_cell.py
@@ -56,7 +56,7 @@ def build(self, input_shape):
           shape=(self.units, self.units),
           initializer='uniform',
           name='recurrent_kernel')
-        self.built = True
+        super().build(input_shape)
 
       def call(self, inputs, states):
         prev_output = states[0]
diff --git a/tf_keras/layers/rnn/base_conv_lstm.py b/tf_keras/layers/rnn/base_conv_lstm.py
index c03e75029..6102d9aad 100644
--- a/tf_keras/layers/rnn/base_conv_lstm.py
+++ b/tf_keras/layers/rnn/base_conv_lstm.py
@@ -218,7 +218,6 @@ def bias_initializer(_, *args, **kwargs):
             )
         else:
             self.bias = None
-        self.built = True
 
     def call(self, inputs, states, training=None):
         h_tm1 = states[0]  # previous memory state
diff --git a/tf_keras/layers/rnn/base_conv_rnn.py b/tf_keras/layers/rnn/base_conv_rnn.py
index d395258c3..c0e651a38 100644
--- a/tf_keras/layers/rnn/base_conv_rnn.py
+++ b/tf_keras/layers/rnn/base_conv_rnn.py
@@ -20,6 +20,7 @@
 from tf_keras import backend
 from tf_keras.engine import base_layer
+from tf_keras.engine.base_layer import Layer
 from tf_keras.engine.input_spec import InputSpec
 from tf_keras.layers.rnn.base_rnn import RNN
 from tf_keras.utils import conv_utils
@@ -207,6 +208,8 @@ def compute_output_shape(self, input_shape):
 
     @tf_utils.shape_type_conversion
     def build(self, input_shape):
+        # Call Layer.build() to skip RNN.build() which we override here.
+        Layer.build(self, input_shape)
         # Note input_shape will be list of shapes of initial states and
         # constants if these are passed in __call__.
         if self._num_constants is not None:
@@ -263,7 +266,6 @@ def build(self, input_shape):
         ]
         if self.stateful:
             self.reset_states()
-        self.built = True
 
     def get_initial_state(self, inputs):
         # (samples, timesteps, img_dims..., filters)
diff --git a/tf_keras/layers/rnn/base_rnn.py b/tf_keras/layers/rnn/base_rnn.py
index 2a310c5f9..45bab492b 100644
--- a/tf_keras/layers/rnn/base_rnn.py
+++ b/tf_keras/layers/rnn/base_rnn.py
@@ -207,7 +207,7 @@ def build(self, input_shape):
           shape=(self.units, self.units),
           initializer='uniform',
           name='recurrent_kernel')
-        self.built = True
+        super().build(input_shape)
 
       def call(self, inputs, states):
         prev_output = states[0]
diff --git a/tf_keras/layers/rnn/base_rnn_test.py b/tf_keras/layers/rnn/base_rnn_test.py
index da21673f1..2e0fcf59b 100644
--- a/tf_keras/layers/rnn/base_rnn_test.py
+++ b/tf_keras/layers/rnn/base_rnn_test.py
@@ -154,7 +154,7 @@ def build(self, input_shape):
             initializer="uniform",
             name="recurrent_kernel",
         )
-        self.built = True
+        super().build(input_shape)
 
     def call(self, inputs, states):
         prev_output = states[0]
@@ -241,7 +241,7 @@ def build(self, input_shape):
             initializer="uniform",
             name="recurrent_kernel",
         )
-        self.built = True
+        super().build(input_shape)
 
     def call(self, inputs, states):
         prev_output = states[0]
@@ -720,7 +720,7 @@ def build(self, input_shape):
             initializer="uniform",
             name="recurrent_kernel",
         )
-        self.built = True
+        super().build(input_shape)
 
     def call(self, inputs, states):
         prev_output = states[0]
@@ -968,7 +968,7 @@ def output_size(self):
 
     def build(self, input_shape):
         self.cell.build(input_shape)
-        self.built = True
+        super().build(input_shape)
 
     def get_initial_state(
         self, inputs=None, batch_size=None, dtype=None
@@ -2034,7 +2034,7 @@ def build(self, input_shape):
             initializer="uniform",
             name="constant_kernel",
         )
-        self.built = True
+        super().build(input_shape)
 
     def call(self, inputs, states, constants):
         [prev_output] = states
@@ -2081,7 +2081,7 @@ def build(self, input_shape):
         self.bias = self.add_weight(
             shape=(self.unit_a, self.unit_b), initializer="uniform", name="bias"
         )
-        self.built = True
+        super().build(input_shape)
 
     def call(self, inputs, states):
         prev_output = states[0]
diff --git a/tf_keras/layers/rnn/base_wrapper.py b/tf_keras/layers/rnn/base_wrapper.py
index 804bfad4f..a6fe01e7f 100644
--- a/tf_keras/layers/rnn/base_wrapper.py
+++ b/tf_keras/layers/rnn/base_wrapper.py
@@ -56,7 +56,7 @@ def build(self, input_shape=None):
         if not self.layer.built:
             self.layer.build(input_shape)
             self.layer.built = True
-        self.built = True
+        super().build(input_shape)
 
     @property
     def activity_regularizer(self):
diff --git a/tf_keras/layers/rnn/bidirectional.py b/tf_keras/layers/rnn/bidirectional.py
index 22c460881..5023404e6 100644
--- a/tf_keras/layers/rnn/bidirectional.py
+++ b/tf_keras/layers/rnn/bidirectional.py
@@ -470,7 +470,8 @@ def build(self, input_shape):
             self.forward_layer.build(input_shape)
         with backend.name_scope(self.backward_layer.name):
             self.backward_layer.build(input_shape)
-        self.built = True
+        # Call Layer.build() to skip Wrapper.build() which we override here.
+        Layer.build(self, input_shape)
 
     def compute_mask(self, inputs, mask):
         if isinstance(mask, list):
diff --git a/tf_keras/layers/rnn/bidirectional_test.py b/tf_keras/layers/rnn/bidirectional_test.py
index f7ecef224..8eab3bda4 100644
--- a/tf_keras/layers/rnn/bidirectional_test.py
+++ b/tf_keras/layers/rnn/bidirectional_test.py
@@ -60,7 +60,7 @@ def build(self, input_shape):
             initializer="uniform",
             name="constant_kernel",
         )
-        self.built = True
+        super().build(input_shape)
 
     def call(self, inputs, states, constants):
         [prev_output] = states
diff --git a/tf_keras/layers/rnn/cell_wrappers.py b/tf_keras/layers/rnn/cell_wrappers.py
index 2754f57e8..de02f0704 100644
--- a/tf_keras/layers/rnn/cell_wrappers.py
+++ b/tf_keras/layers/rnn/cell_wrappers.py
@@ -102,10 +102,10 @@ def call(self, inputs, state, **kwargs):
             inputs, state, cell_call_fn=self.cell.call, **kwargs
         )
 
-    def build(self, inputs_shape):
+    def build(self, input_shape):
         """Builds the wrapped cell."""
-        self.cell.build(inputs_shape)
-        self.built = True
+        self.cell.build(input_shape)
+        super().build(input_shape)
 
     @property
     def wrapped_cell(self):
diff --git a/tf_keras/layers/rnn/cudnn_gru.py b/tf_keras/layers/rnn/cudnn_gru.py
index c98ece8b1..c6f9510a1 100644
--- a/tf_keras/layers/rnn/cudnn_gru.py
+++ b/tf_keras/layers/rnn/cudnn_gru.py
@@ -144,8 +144,6 @@ def build(self, input_shape):
             constraint=self.bias_constraint,
         )
 
-        self.built = True
-
     def _process_batch(self, inputs, initial_state):
         if not self.time_major:
             inputs = tf.transpose(inputs, perm=(1, 0, 2))
diff --git a/tf_keras/layers/rnn/cudnn_lstm.py b/tf_keras/layers/rnn/cudnn_lstm.py
index 2440c352b..804fa67cd 100644
--- a/tf_keras/layers/rnn/cudnn_lstm.py
+++ b/tf_keras/layers/rnn/cudnn_lstm.py
@@ -170,8 +170,6 @@ def bias_initializer(_, *args, **kwargs):
             constraint=self.bias_constraint,
         )
 
-        self.built = True
-
     def _process_batch(self, inputs, initial_state):
         if not self.time_major:
             inputs = tf.transpose(inputs, perm=(1, 0, 2))
diff --git a/tf_keras/layers/rnn/gru.py b/tf_keras/layers/rnn/gru.py
index dd2443be5..993740084 100644
--- a/tf_keras/layers/rnn/gru.py
+++ b/tf_keras/layers/rnn/gru.py
@@ -222,7 +222,6 @@ def build(self, input_shape):
             )
         else:
             self.bias = None
-        self.built = True
 
     def call(self, inputs, states, training=None):
         h_tm1 = (
diff --git a/tf_keras/layers/rnn/legacy_cell_wrappers.py b/tf_keras/layers/rnn/legacy_cell_wrappers.py
index a8a007057..a46288277 100644
--- a/tf_keras/layers/rnn/legacy_cell_wrappers.py
+++ b/tf_keras/layers/rnn/legacy_cell_wrappers.py
@@ -368,9 +368,9 @@ def _gen_seed(self, salt_prefix, index):
     def wrapped_cell(self):
         return self.cell
 
-    def build(self, inputs_shape):
-        self.cell.build(inputs_shape)
-        self.built = True
+    def build(self, input_shape):
+        self.cell.build(input_shape)
+        super().build(input_shape)
 
     def _variational_recurrent_dropout_value(
         self, unused_index, value, noise, keep_prob
diff --git a/tf_keras/layers/rnn/legacy_cells.py b/tf_keras/layers/rnn/legacy_cells.py
index d5053c8b4..d9414478e 100644
--- a/tf_keras/layers/rnn/legacy_cells.py
+++ b/tf_keras/layers/rnn/legacy_cells.py
@@ -246,11 +246,6 @@ def output_size(self):
         """Integer or TensorShape: size of outputs produced by this cell."""
         raise NotImplementedError("Abstract method")
 
-    def build(self, _):
-        # This tells the parent Layer object that it's OK to call
-        # self.add_weight() inside the call() method.
-        pass
-
     def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
         if inputs is not None:
             # Validate the given batch_size and dtype against inputs if
@@ -445,15 +440,15 @@ def output_size(self):
         return self._num_units
 
     @tf_utils.shape_type_conversion
-    def build(self, inputs_shape):
-        if inputs_shape[-1] is None:
+    def build(self, input_shape):
+        if input_shape[-1] is None:
             raise ValueError(
                 "Expected inputs.shape[-1] to be known, "
-                f"received shape: {inputs_shape}"
+                f"received shape: {input_shape}"
             )
         _check_supported_dtypes(self.dtype)
-        input_depth = inputs_shape[-1]
+        input_depth = input_shape[-1]
         self._kernel = self.add_weight(
             _WEIGHTS_VARIABLE_NAME,
             shape=[input_depth + self._num_units, self._num_units],
@@ -464,7 +459,7 @@ def build(self, inputs_shape):
             initializer=tf.compat.v1.zeros_initializer(dtype=self.dtype),
         )
 
-        self.built = True
+        super().build(input_shape)
 
     def call(self, inputs, state):
         """Most basic RNN: output = new_state = act(W * input + U * state +
@@ -563,14 +558,14 @@ def output_size(self):
         return self._num_units
 
     @tf_utils.shape_type_conversion
-    def build(self, inputs_shape):
-        if inputs_shape[-1] is None:
+    def build(self, input_shape):
+        if input_shape[-1] is None:
             raise ValueError(
                 "Expected inputs.shape[-1] to be known, "
-                f"received shape: {inputs_shape}"
+                f"received shape: {input_shape}"
             )
         _check_supported_dtypes(self.dtype)
-        input_depth = inputs_shape[-1]
+        input_depth = input_shape[-1]
         self._gate_kernel = self.add_weight(
             f"gates/{_WEIGHTS_VARIABLE_NAME}",
             shape=[input_depth + self._num_units, 2 * self._num_units],
@@ -600,7 +595,7 @@ def build(self, inputs_shape):
             ),
         )
 
-        self.built = True
+        super().build(input_shape)
 
     def call(self, inputs, state):
         """Gated recurrent unit (GRU) with nunits cells."""
@@ -774,14 +769,14 @@ def output_size(self):
         return self._num_units
 
     @tf_utils.shape_type_conversion
-    def build(self, inputs_shape):
-        if inputs_shape[-1] is None:
+    def build(self, input_shape):
+        if input_shape[-1] is None:
             raise ValueError(
                 "Expected inputs.shape[-1] to be known, "
-                f"received shape: {inputs_shape}"
+                f"received shape: {input_shape}"
             )
         _check_supported_dtypes(self.dtype)
-        input_depth = inputs_shape[-1]
+        input_depth = input_shape[-1]
         h_depth = self._num_units
         self._kernel = self.add_weight(
             _WEIGHTS_VARIABLE_NAME,
@@ -793,7 +788,7 @@ def build(self, inputs_shape):
             initializer=tf.compat.v1.zeros_initializer(dtype=self.dtype),
         )
 
-        self.built = True
+        super().build(input_shape)
 
     def call(self, inputs, state):
"""Long short-term memory cell (LSTM). @@ -1017,14 +1012,14 @@ def output_size(self): return self._output_size @tf_utils.shape_type_conversion - def build(self, inputs_shape): - if inputs_shape[-1] is None: + def build(self, input_shape): + if input_shape[-1] is None: raise ValueError( "Expected inputs.shape[-1] to be known, " - f"received shape: {inputs_shape}" + f"received shape: {input_shape}" ) _check_supported_dtypes(self.dtype) - input_depth = inputs_shape[-1] + input_depth = input_shape[-1] h_depth = self._num_units if self._num_proj is None else self._num_proj maybe_partitioner = ( tf.compat.v1.fixed_size_partitioner(self._num_unit_shards) @@ -1076,7 +1071,7 @@ def build(self, inputs_shape): partitioner=maybe_proj_partitioner, ) - self.built = True + super().build(input_shape) def call(self, inputs, state): """Run one step of LSTM. diff --git a/tf_keras/layers/rnn/lstm.py b/tf_keras/layers/rnn/lstm.py index 8bfa5676f..dd749444d 100644 --- a/tf_keras/layers/rnn/lstm.py +++ b/tf_keras/layers/rnn/lstm.py @@ -236,7 +236,6 @@ def bias_initializer(_, *args, **kwargs): ) else: self.bias = None - self.built = True def _compute_carry_and_output(self, x, h_tm1, c_tm1): """Computes carry and output using split kernels.""" diff --git a/tf_keras/layers/rnn/simple_rnn.py b/tf_keras/layers/rnn/simple_rnn.py index 57a8ae8ee..3edb6b91c 100644 --- a/tf_keras/layers/rnn/simple_rnn.py +++ b/tf_keras/layers/rnn/simple_rnn.py @@ -189,7 +189,6 @@ def build(self, input_shape): ) else: self.bias = None - self.built = True def call(self, inputs, states, training=None): prev_output = states[0] if tf.nest.is_nested(states) else states diff --git a/tf_keras/layers/rnn/stacked_rnn_cells.py b/tf_keras/layers/rnn/stacked_rnn_cells.py index c824dbc55..acfdb0cda 100644 --- a/tf_keras/layers/rnn/stacked_rnn_cells.py +++ b/tf_keras/layers/rnn/stacked_rnn_cells.py @@ -166,6 +166,7 @@ def call(self, inputs, states, constants=None, training=None, **kwargs): @tf_utils.shape_type_conversion def build(self, input_shape): + super().build(input_shape) if isinstance(input_shape, list): input_shape = input_shape[0] @@ -195,7 +196,6 @@ def get_batch_input_shape(batch_size, dim): input_shape = tuple( [batch_size] + tf.TensorShape(output_dim).as_list() ) - self.built = True def get_config(self): cells = [] diff --git a/tf_keras/layers/rnn/time_distributed.py b/tf_keras/layers/rnn/time_distributed.py index b807ccd12..9dccebfaa 100644 --- a/tf_keras/layers/rnn/time_distributed.py +++ b/tf_keras/layers/rnn/time_distributed.py @@ -135,7 +135,6 @@ def build(self, input_shape): ) child_input_shape = tf_utils.convert_shapes(child_input_shape) super().build(tuple(child_input_shape)) - self.built = True def compute_output_shape(self, input_shape): input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False) diff --git a/tf_keras/legacy_tf_layers/base_test.py b/tf_keras/legacy_tf_layers/base_test.py index fc969724b..184c4726f 100644 --- a/tf_keras/legacy_tf_layers/base_test.py +++ b/tf_keras/legacy_tf_layers/base_test.py @@ -505,7 +505,7 @@ def testNameScopeIsConsistentWithVariableScope(self): class MyLayer(base_tf_layers.Layer): def build(self, input_shape): self.my_var = self.add_weight("my_var", (), tf.float32) - self.built = True + super().build(input_shape) def call(self, inputs): return tf.multiply(inputs, self.my_var, name="my_op") @@ -559,7 +559,7 @@ def testVariablesAreLiftedFromFunctionBuildingGraphs(self): class MyLayer(base_tf_layers.Layer): def build(self, input_shape): self.my_var = self.add_weight("my_var", (), 
tf.float32) - self.built = True + super().build(input_shape) def call(self, inputs): return inputs @@ -587,7 +587,7 @@ def build(self, input_shape): self.add_update( tf.compat.v1.assign_add(self.a, 1.0, name="b_update") ) - self.built = True + super().build(input_shape) def call(self, inputs): self.add_update( @@ -629,7 +629,7 @@ def build(self, input_shape): self.a = self.add_weight("a", (), tf.float32, trainable=False) self.b = self.add_weight("b", (), tf.float32, trainable=False) self.add_loss(self.a) - self.built = True + super().build(input_shape) def call(self, inputs): self.add_loss(inputs, inputs=True) diff --git a/tf_keras/mixed_precision/test_util.py b/tf_keras/mixed_precision/test_util.py index 22cd085e4..ecc4bca5b 100644 --- a/tf_keras/mixed_precision/test_util.py +++ b/tf_keras/mixed_precision/test_util.py @@ -171,14 +171,14 @@ def __init__( activity_regularizer=self._activity_regularizer, **kwargs ) - def build(self, _): + def build(self, input_shape): self.v = self.add_weight( self._var_name, (), initializer="ones", regularizer=self._regularizer, ) - self.built = True + super().build(input_shape) def call(self, inputs): self.assert_input_types(inputs) @@ -205,7 +205,7 @@ def get_config(self): class MultiplyLayerWithoutAutoCast(MultiplyLayer): """Same as MultiplyLayer, but does not use AutoCastVariables.""" - def build(self, _): + def build(self, input_shape): dtype = self.dtype if dtype in ("float16", "bfloat16"): dtype = "float32" @@ -217,7 +217,8 @@ def build(self, _): autocast=False, regularizer=self._regularizer, ) - self.built = True + # Call Layer.build() to skip MultiplyLayer.build() which we override. + base_layer.Layer.build(self, input_shape) def call(self, inputs): self.assert_input_types(inputs) diff --git a/tf_keras/premade_models/linear.py b/tf_keras/premade_models/linear.py index ce48655c2..99e0a509c 100644 --- a/tf_keras/premade_models/linear.py +++ b/tf_keras/premade_models/linear.py @@ -156,7 +156,8 @@ def build(self, input_shape): ) else: self.bias = None - self.built = True + # Call Layer.build() to skip Model.build() which we override here. + base_layer.Layer.build(self, input_shape) def call(self, inputs): result = None diff --git a/tf_keras/saving/legacy/saved_model/saved_model_test.py b/tf_keras/saving/legacy/saved_model/saved_model_test.py index 65c0ff31c..91c1b18c8 100644 --- a/tf_keras/saving/legacy/saved_model/saved_model_test.py +++ b/tf_keras/saving/legacy/saved_model/saved_model_test.py @@ -53,7 +53,7 @@ def build(self, input_shape): self.input_spec = keras.layers.InputSpec( shape=[None] * len(input_shape) ) - self.built = True + super().build(input_shape) def call(self, x, training=None): if training is None: