diff --git a/tf_keras/BUILD b/tf_keras/BUILD index 5e4345c5c..b3d6f8980 100644 --- a/tf_keras/BUILD +++ b/tf_keras/BUILD @@ -1,8 +1,8 @@ # Description: # Contains the TF-Keras API (internal TensorFlow version). -# Placeholder: load unaliased py_library load("@org_keras//tf_keras:tf_keras.bzl", "tf_py_test") +# Placeholder: load unaliased py_library # copybara:uncomment_begin(google-only) # load("//tools/build_defs/license:license.bzl", "license") diff --git a/tf_keras/benchmarks/BUILD b/tf_keras/benchmarks/BUILD index 6466aada5..105c29766 100644 --- a/tf_keras/benchmarks/BUILD +++ b/tf_keras/benchmarks/BUILD @@ -1,10 +1,10 @@ # Description: # Implementation of TF-Keras benchmarks. +load("@org_keras//tf_keras:tf_keras.bzl", "cuda_py_test") # Placeholder: load unaliased py_binary # Placeholder: load unaliased py_library # Placeholder: load unaliased py_test -load("@org_keras//tf_keras:tf_keras.bzl", "cuda_py_test") package( # copybara:uncomment default_applicable_licenses = ["//tf_keras:license"], diff --git a/tf_keras/benchmarks/layer_benchmarks/BUILD b/tf_keras/benchmarks/layer_benchmarks/BUILD index fab88a4e7..c38333a8c 100644 --- a/tf_keras/benchmarks/layer_benchmarks/BUILD +++ b/tf_keras/benchmarks/layer_benchmarks/BUILD @@ -1,8 +1,8 @@ # Description: # Implementation of benchmarks on TF-Keras layers. -# Placeholder: load unaliased py_library load("@org_keras//tf_keras:tf_keras.bzl", "tf_py_test") +# Placeholder: load unaliased py_library package( # copybara:uncomment default_applicable_licenses = ["//tf_keras:license"], diff --git a/tf_keras/benchmarks/saved_model_benchmarks/BUILD b/tf_keras/benchmarks/saved_model_benchmarks/BUILD index db3e3ecfe..2cf94e98c 100644 --- a/tf_keras/benchmarks/saved_model_benchmarks/BUILD +++ b/tf_keras/benchmarks/saved_model_benchmarks/BUILD @@ -1,8 +1,8 @@ # Description: # Implementation of TF-Keras benchmarks. -# Placeholder: load unaliased py_library load("@org_keras//tf_keras:tf_keras.bzl", "cuda_py_test") +# Placeholder: load unaliased py_library package( # copybara:uncomment default_applicable_licenses = ["//tf_keras:license"], diff --git a/tf_keras/callbacks.py b/tf_keras/callbacks.py index 54ff13bd7..013c41d02 100644 --- a/tf_keras/callbacks.py +++ b/tf_keras/callbacks.py @@ -1903,6 +1903,23 @@ def on_train_begin(self, logs=None): "only supports empty strategy, " "MirroredStrategy, MultiWorkerMirroredStrategy and TPUStrategy." ) + + # Re-initialize the optimizer. + if self.model.built: + if ( + self.model.optimizer is not None + and callable(getattr(self.model.optimizer, "build", None)) + and not getattr(self.model.optimizer, "_built", False) + ): + self.model.optimizer.build(self.model.trainable_variables) + else: + logging.warning( + "To use the BackupAndRestore callback, " + "you model must be built before you call `fit()`. " + f"Model {self.model} is unbuilt. You can build it " + "beforehand by calling it on a batch of data." + ) + self.model._training_state = worker_training_state.WorkerTrainingState( self.model, self.backup_dir, diff --git a/tf_keras/callbacks_test.py b/tf_keras/callbacks_test.py index c5c15500d..7df78fad1 100644 --- a/tf_keras/callbacks_test.py +++ b/tf_keras/callbacks_test.py @@ -471,11 +471,17 @@ class InterruptingCallback(keras.callbacks.Callback): def on_epoch_end(self, epoch, log=None): if epoch == epoch_int: + # Re-initialize optimizer to test state restore. + self.model.optimizer = sgd.SGD() + raise RuntimeError("EpochInterruption") def on_batch_end(self, batch, logs=None): self.batch_count += 1 if self.batch_count == steps_int: + # Re-initialize optimizer to test state restore. + self.model.optimizer = sgd.SGD() + raise RuntimeError("StepsInterruption") class VerifyRestore(Callback): @@ -505,6 +511,12 @@ def on_batch_begin(self, batch, logs=None): % (self.initial_epoch, self.initial_step) ) + def on_train_begin(self, logs=None): + if self.model.optimizer is None or not getattr( + self.model.optimizer, "_built", False + ): + raise ValueError("Optimizer did not restore at train begin") + model = keras.Sequential([keras.layers.Dense(10)]) optimizer = sgd.SGD() model.compile(optimizer, loss="mse") diff --git a/tf_keras/distribute/BUILD b/tf_keras/distribute/BUILD index 155c97109..2aafe2dc3 100644 --- a/tf_keras/distribute/BUILD +++ b/tf_keras/distribute/BUILD @@ -2,10 +2,10 @@ # keras/distribute package is intended to serve as the centralized place for things # related to dist-strat used by TF-Keras.. -# Placeholder: load unaliased py_library -load("@org_keras//tf_keras:tf_keras.bzl", "cuda_py_test") +load("@org_keras//tf_keras:tf_keras.bzl", "cuda_py_test") # buildifier: disable=same-origin-load load("@org_keras//tf_keras:tf_keras.bzl", "tf_py_test") # buildifier: disable=same-origin-load -load("@org_keras//tf_keras:tf_keras.bzl", "distribute_py_test") +load("@org_keras//tf_keras:tf_keras.bzl", "distribute_py_test") # buildifier: disable=same-origin-load +# Placeholder: load unaliased py_library package( # copybara:uncomment default_applicable_licenses = ["//tf_keras:license"], diff --git a/tf_keras/engine/BUILD b/tf_keras/engine/BUILD index 3611ad889..918e4bee0 100644 --- a/tf_keras/engine/BUILD +++ b/tf_keras/engine/BUILD @@ -1,14 +1,10 @@ # Description: # Contains the TF-Keras engine API (internal TensorFlow version). +load("@org_keras//tf_keras:tf_keras.bzl", "tf_py_test") # buildifier: disable=same-origin-load +load("@org_keras//tf_keras:tf_keras.bzl", "cuda_py_test") # buildifier: disable=same-origin-load # Placeholder: load unaliased py_library -# buildifier: disable=same-origin-load -load("@org_keras//tf_keras:tf_keras.bzl", "tf_py_test") - -# buildifier: disable=same-origin-load -load("@org_keras//tf_keras:tf_keras.bzl", "cuda_py_test") - package( # copybara:uncomment default_applicable_licenses = ["//tf_keras:license"], # TODO(scottzhu): Remove non-keras deps from TF. diff --git a/tf_keras/integration_test/BUILD b/tf_keras/integration_test/BUILD index d9b47df2b..756955120 100644 --- a/tf_keras/integration_test/BUILD +++ b/tf_keras/integration_test/BUILD @@ -1,11 +1,11 @@ # Description: # Contains TF-Keras integration tests that verify with other TF high level APIs. -# Placeholder: load unaliased py_library -load("@org_keras//tf_keras:tf_keras.bzl", "cuda_py_test") +load("@org_keras//tf_keras:tf_keras.bzl", "cuda_py_test") # buildifier: disable=same-origin-load load("@org_keras//tf_keras:tf_keras.bzl", "tf_py_test") # buildifier: disable=same-origin-load -load("@org_keras//tf_keras:tf_keras.bzl", "distribute_py_test") -load("@org_keras//tf_keras:tf_keras.bzl", "tpu_py_test") +load("@org_keras//tf_keras:tf_keras.bzl", "distribute_py_test") # buildifier: disable=same-origin-load +load("@org_keras//tf_keras:tf_keras.bzl", "tpu_py_test") # buildifier: disable=same-origin-load +# Placeholder: load unaliased py_library package( # copybara:uncomment default_applicable_licenses = ["//tf_keras:license"], diff --git a/tf_keras/layers/BUILD b/tf_keras/layers/BUILD index 1e50f4d43..e4883b3b9 100644 --- a/tf_keras/layers/BUILD +++ b/tf_keras/layers/BUILD @@ -1,8 +1,8 @@ # Description: # Contains the TF-Keras layers (internal TensorFlow version). -# Placeholder: load unaliased py_library load("@org_keras//tf_keras:tf_keras.bzl", "tf_py_test") +# Placeholder: load unaliased py_library package( # copybara:uncomment default_applicable_licenses = ["//tf_keras:license"], diff --git a/tf_keras/layers/convolutional/BUILD b/tf_keras/layers/convolutional/BUILD index b31a2f07b..c964a0d45 100644 --- a/tf_keras/layers/convolutional/BUILD +++ b/tf_keras/layers/convolutional/BUILD @@ -1,8 +1,8 @@ # Description: # Contains the TF-Keras convolution layers. -# Placeholder: load unaliased py_library load("@org_keras//tf_keras:tf_keras.bzl", "cuda_py_test") +# Placeholder: load unaliased py_library package( # copybara:uncomment default_applicable_licenses = ["//tf_keras:license"], diff --git a/tf_keras/layers/experimental/BUILD b/tf_keras/layers/experimental/BUILD index bb596a256..6d1c4bc16 100644 --- a/tf_keras/layers/experimental/BUILD +++ b/tf_keras/layers/experimental/BUILD @@ -3,8 +3,8 @@ # the training process. # Placeholder: load unaliased py_library -load("@org_keras//tf_keras:tf_keras.bzl", "tf_py_test") -load("@org_keras//tf_keras:tf_keras.bzl", "distribute_py_test") +load("@org_keras//tf_keras:tf_keras.bzl", "tf_py_test") # buildifier: disable=same-origin-load +load("@org_keras//tf_keras:tf_keras.bzl", "distribute_py_test") # buildifier: disable=same-origin-load package( # copybara:uncomment default_applicable_licenses = ["//third_party/py/keras:license"], diff --git a/tf_keras/layers/normalization/BUILD b/tf_keras/layers/normalization/BUILD index a465411d5..b4d51aa80 100644 --- a/tf_keras/layers/normalization/BUILD +++ b/tf_keras/layers/normalization/BUILD @@ -1,14 +1,10 @@ # Description: # Contains the TF-Keras normalization layers (internal TensorFlow version). +load("@org_keras//tf_keras:tf_keras.bzl", "cuda_py_test") # buildifier: disable=same-origin-load +load("@org_keras//tf_keras:tf_keras.bzl", "tf_py_test") # buildifier: disable=same-origin-load # Placeholder: load unaliased py_library -# buildifier: disable=same-origin-load -load("@org_keras//tf_keras:tf_keras.bzl", "cuda_py_test") - -# buildifier: disable=same-origin-load -load("@org_keras//tf_keras:tf_keras.bzl", "tf_py_test") - package( # copybara:uncomment default_applicable_licenses = ["//tf_keras:license"], # TODO(scottzhu): Remove non-keras deps from TF. diff --git a/tf_keras/layers/preprocessing/BUILD b/tf_keras/layers/preprocessing/BUILD index e80b66aed..5111c6b86 100644 --- a/tf_keras/layers/preprocessing/BUILD +++ b/tf_keras/layers/preprocessing/BUILD @@ -2,11 +2,9 @@ # Contains the TF-Keras preprocess layers (internal TensorFlow version). # Placeholder: load unaliased py_library -load("@org_keras//tf_keras:tf_keras.bzl", "tf_py_test") - -# buildifier: disable=same-origin-load -load("@org_keras//tf_keras:tf_keras.bzl", "cuda_py_test") -load("@org_keras//tf_keras:tf_keras.bzl", "distribute_py_test") +load("@org_keras//tf_keras:tf_keras.bzl", "tf_py_test") # buildifier: disable=same-origin-load +load("@org_keras//tf_keras:tf_keras.bzl", "cuda_py_test") # buildifier: disable=same-origin-load +load("@org_keras//tf_keras:tf_keras.bzl", "distribute_py_test") # buildifier: disable=same-origin-load package( # copybara:uncomment default_applicable_licenses = ["//tf_keras:license"], diff --git a/tf_keras/layers/preprocessing/benchmarks/BUILD b/tf_keras/layers/preprocessing/benchmarks/BUILD index a13af0c4d..babcb14d6 100644 --- a/tf_keras/layers/preprocessing/benchmarks/BUILD +++ b/tf_keras/layers/preprocessing/benchmarks/BUILD @@ -1,10 +1,7 @@ -# Placeholder: load unaliased py_library - # Benchmarks for TF-Keras preprocessing layers. -load("@org_keras//tf_keras:tf_keras.bzl", "cuda_py_test") - -# buildifier: disable=same-origin-load -load("@org_keras//tf_keras:tf_keras.bzl", "tf_py_test") +load("@org_keras//tf_keras:tf_keras.bzl", "cuda_py_test") # buildifier: disable=same-origin-load +load("@org_keras//tf_keras:tf_keras.bzl", "tf_py_test") # buildifier: disable=same-origin-load +# Placeholder: load unaliased py_library package( # copybara:uncomment default_applicable_licenses = ["//tf_keras:license"], diff --git a/tf_keras/layers/rnn/BUILD b/tf_keras/layers/rnn/BUILD index f26686395..a7a528c04 100644 --- a/tf_keras/layers/rnn/BUILD +++ b/tf_keras/layers/rnn/BUILD @@ -1,11 +1,9 @@ # Description: # Contains the TF-Keras recurrent layers. +load("@org_keras//tf_keras:tf_keras.bzl", "cuda_py_test") # buildifier: disable=same-origin-load +load("@org_keras//tf_keras:tf_keras.bzl", "tf_py_test") # buildifier: disable=same-origin-load # Placeholder: load unaliased py_library -load("@org_keras//tf_keras:tf_keras.bzl", "cuda_py_test") - -# buildifier: disable=same-origin-load -load("@org_keras//tf_keras:tf_keras.bzl", "tf_py_test") package( # copybara:uncomment default_applicable_licenses = ["//tf_keras:license"], diff --git a/tf_keras/layers/rnn/cudnn_gru.py b/tf_keras/layers/rnn/cudnn_gru.py index ea9fb56d5..c98ece8b1 100644 --- a/tf_keras/layers/rnn/cudnn_gru.py +++ b/tf_keras/layers/rnn/cudnn_gru.py @@ -172,6 +172,10 @@ def _process_batch(self, inputs, initial_state): shape=self._vector_shape, ) + batch_dim = tf.shape(inputs)[1] + max_sequence_length = tf.shape(inputs)[0] + sequence_lengths = tf.fill([batch_dim], max_sequence_length) + args = { "input": inputs, "input_h": input_h, @@ -179,9 +183,10 @@ def _process_batch(self, inputs, initial_state): "params": params, "is_training": True, "rnn_mode": "gru", + "sequence_lengths": sequence_lengths, } - outputs, h, _, _, _ = tf.raw_ops.CudnnRNNV2(**args) + outputs, h, _, _, _ = tf.raw_ops.CudnnRNNV3(**args) if self.stateful or self.return_state: h = h[0] diff --git a/tf_keras/layers/rnn/cudnn_lstm.py b/tf_keras/layers/rnn/cudnn_lstm.py index a1a6832c5..2440c352b 100644 --- a/tf_keras/layers/rnn/cudnn_lstm.py +++ b/tf_keras/layers/rnn/cudnn_lstm.py @@ -204,15 +204,20 @@ def _process_batch(self, inputs, initial_state): shape=self._vector_shape, ) + batch_dim = tf.shape(inputs)[1] + max_sequence_length = tf.shape(inputs)[0] + sequence_lengths = tf.fill([batch_dim], max_sequence_length) + args = { "input": inputs, "input_h": input_h, "input_c": input_c, "params": params, "is_training": True, + "sequence_lengths": sequence_lengths, } - outputs, h, c, _, _ = tf.raw_ops.CudnnRNNV2(**args) + outputs, h, c, _, _ = tf.raw_ops.CudnnRNNV3(**args) if self.stateful or self.return_state: h = h[0] diff --git a/tf_keras/layers/rnn/gru.py b/tf_keras/layers/rnn/gru.py index 32f6f4fcf..dd2443be5 100644 --- a/tf_keras/layers/rnn/gru.py +++ b/tf_keras/layers/rnn/gru.py @@ -1034,11 +1034,13 @@ def gpu_gru( mask, time_major ) - if not time_major and sequence_lengths is None: - inputs = tf.transpose(inputs, perm=(1, 0, 2)) - seq_axis, batch_axis = (0, 1) - else: - seq_axis, batch_axis = (0, 1) if time_major else (1, 0) + seq_axis, batch_axis = (0, 1) if time_major else (1, 0) + + if sequence_lengths is None: + max_sequence_length = tf.shape(inputs)[seq_axis] + batch_size = tf.shape(inputs)[batch_axis] + sequence_lengths = tf.fill([batch_size], max_sequence_length) + # For init_h, cuDNN expects one more dim of num_layers before or after batch # dim for time major or batch major inputs respectively init_h = tf.expand_dims(init_h, axis=seq_axis) @@ -1069,49 +1071,36 @@ def gpu_gru( transpose_weights=True, ) - if sequence_lengths is not None: - if go_backwards: - # Three reversals are required. E.g., - # normal input = [1, 2, 3, 0, 0] # where 0 need to be masked - # reversed_input_to_cudnn = [3, 2, 1, 0, 0] - # output_from_cudnn = [6, 5, 4, 0, 0] - # expected_output = [0, 0, 6, 5 ,4] - inputs = tf.reverse_sequence( - inputs, - sequence_lengths, - seq_axis=seq_axis, - batch_axis=batch_axis, - ) - outputs, h, _, _, _ = tf.raw_ops.CudnnRNNV3( - input=inputs, - input_h=init_h, - input_c=0, - params=params, - is_training=True, - rnn_mode="gru", - sequence_lengths=sequence_lengths, - time_major=time_major, + if go_backwards: + # Three reversals are required. E.g., + # normal input = [1, 2, 3, 0, 0] # where 0 need to be masked + # reversed_input_to_cudnn = [3, 2, 1, 0, 0] + # output_from_cudnn = [6, 5, 4, 0, 0] + # expected_output = [0, 0, 6, 5 ,4] + inputs = tf.reverse_sequence( + inputs, + sequence_lengths, + seq_axis=seq_axis, + batch_axis=batch_axis, ) - if go_backwards: - outputs = tf.reverse_sequence( - outputs, - sequence_lengths, - seq_axis=seq_axis, - batch_axis=batch_axis, - ) - outputs = tf.reverse(outputs, axis=[seq_axis]) - else: - if go_backwards: - # Reverse axis 0 since the input is already convert to time major. - inputs = tf.reverse(inputs, axis=[0]) - outputs, h, _, _ = tf.raw_ops.CudnnRNN( - input=inputs, - input_h=init_h, - input_c=0, - params=params, - is_training=True, - rnn_mode="gru", + outputs, h, _, _, _ = tf.raw_ops.CudnnRNNV3( + input=inputs, + input_h=init_h, + input_c=0, + params=params, + is_training=True, + rnn_mode="gru", + sequence_lengths=sequence_lengths, + time_major=time_major, + ) + if go_backwards: + outputs = tf.reverse_sequence( + outputs, + sequence_lengths, + seq_axis=seq_axis, + batch_axis=batch_axis, ) + outputs = tf.reverse(outputs, axis=[seq_axis]) last_output = outputs[-1] if not time_major and sequence_lengths is None and return_sequences: diff --git a/tf_keras/layers/rnn/lstm.py b/tf_keras/layers/rnn/lstm.py index 07800d93f..8bfa5676f 100644 --- a/tf_keras/layers/rnn/lstm.py +++ b/tf_keras/layers/rnn/lstm.py @@ -1063,11 +1063,13 @@ def gpu_lstm( mask, time_major ) - if not time_major and sequence_lengths is None: - inputs = tf.transpose(inputs, perm=(1, 0, 2)) - seq_axis, batch_axis = (0, 1) - else: - seq_axis, batch_axis = (0, 1) if time_major else (1, 0) + seq_axis, batch_axis = (0, 1) if time_major else (1, 0) + + if sequence_lengths is None: + max_sequence_length = tf.shape(inputs)[seq_axis] + batch_size = tf.shape(inputs)[batch_axis] + sequence_lengths = tf.fill([batch_size], max_sequence_length) + # For init_h and init_c, cuDNN expects one more dim of num_layers before or # after batch dim for time major or batch major inputs respectively init_h = tf.expand_dims(init_h, axis=seq_axis) @@ -1099,52 +1101,36 @@ def gpu_lstm( transpose_weights=True, ) - if sequence_lengths is not None: - if go_backwards: - # Three reversals are required. E.g., - # normal input = [1, 2, 3, 0, 0] # where 0 need to be masked - # reversed_input_to_cudnn = [3, 2, 1, 0, 0] - # output_from_cudnn = [6, 5, 4, 0, 0] - # expected_output = [0, 0, 6, 5 ,4] - inputs = tf.reverse_sequence( - inputs, - sequence_lengths, - seq_axis=seq_axis, - batch_axis=batch_axis, - ) - outputs, h, c, _, _ = tf.raw_ops.CudnnRNNV3( - input=inputs, - input_h=init_h, - input_c=init_c, - params=params, - is_training=True, - rnn_mode="lstm", - sequence_lengths=sequence_lengths, - time_major=time_major, + if go_backwards: + # Three reversals are required. E.g., + # normal input = [1, 2, 3, 0, 0] # where 0 need to be masked + # reversed_input_to_cudnn = [3, 2, 1, 0, 0] + # output_from_cudnn = [6, 5, 4, 0, 0] + # expected_output = [0, 0, 6, 5 ,4] + inputs = tf.reverse_sequence( + inputs, + sequence_lengths, + seq_axis=seq_axis, + batch_axis=batch_axis, ) - if go_backwards: - outputs = tf.reverse_sequence( - outputs, - sequence_lengths, - seq_axis=seq_axis, - batch_axis=batch_axis, - ) - outputs = tf.reverse(outputs, axis=[seq_axis]) - else: - # # Fill the array with shape [batch] with value of max timesteps. - # sequence_length = array_ops.fill([array_ops.shape(inputs)[1]], - # array_ops.shape(inputs)[0]) - if go_backwards: - # Reverse axis 0 since the input is already convert to time major. - inputs = tf.reverse(inputs, axis=[0]) - outputs, h, c, _ = tf.raw_ops.CudnnRNN( - input=inputs, - input_h=init_h, - input_c=init_c, - params=params, - is_training=True, - rnn_mode="lstm", + outputs, h, c, _, _ = tf.raw_ops.CudnnRNNV3( + input=inputs, + input_h=init_h, + input_c=init_c, + params=params, + is_training=True, + rnn_mode="lstm", + sequence_lengths=sequence_lengths, + time_major=time_major, + ) + if go_backwards: + outputs = tf.reverse_sequence( + outputs, + sequence_lengths, + seq_axis=seq_axis, + batch_axis=batch_axis, ) + outputs = tf.reverse(outputs, axis=[seq_axis]) last_output = outputs[-1] if not time_major and sequence_lengths is None and return_sequences: diff --git a/tf_keras/legacy_tf_layers/BUILD b/tf_keras/legacy_tf_layers/BUILD index c0ea17cdb..a23a17709 100644 --- a/tf_keras/legacy_tf_layers/BUILD +++ b/tf_keras/legacy_tf_layers/BUILD @@ -1,12 +1,10 @@ # Description: # Contains the legacy TF layers (internal TensorFlow version). +load("@org_keras//tf_keras:tf_keras.bzl", "tf_py_test") # buildifier: disable=same-origin-load +load("@org_keras//tf_keras:tf_keras.bzl", "cuda_py_test") # buildifier: disable=same-origin-load # Placeholder: load unaliased py_library -# buildifier: disable=same-origin-load -load("@org_keras//tf_keras:tf_keras.bzl", "tf_py_test") -load("@org_keras//tf_keras:tf_keras.bzl", "cuda_py_test") - package( # copybara:uncomment default_applicable_licenses = ["//tf_keras:license"], default_visibility = [ diff --git a/tf_keras/mixed_precision/BUILD b/tf_keras/mixed_precision/BUILD index 1d2b1dfa1..bc48de016 100644 --- a/tf_keras/mixed_precision/BUILD +++ b/tf_keras/mixed_precision/BUILD @@ -16,9 +16,9 @@ # Description: # Contains the TF-Keras Mixed Precision API (TensorFlow version). -# Placeholder: load unaliased py_library -load("@org_keras//tf_keras:tf_keras.bzl", "cuda_py_test") +load("@org_keras//tf_keras:tf_keras.bzl", "cuda_py_test") # buildifier: disable=same-origin-load load("@org_keras//tf_keras:tf_keras.bzl", "tf_py_test") # buildifier: disable=same-origin-load +# Placeholder: load unaliased py_library package( # copybara:uncomment default_applicable_licenses = ["//tf_keras:license"], diff --git a/tf_keras/models/BUILD b/tf_keras/models/BUILD index 3d065365c..af0d6a24d 100644 --- a/tf_keras/models/BUILD +++ b/tf_keras/models/BUILD @@ -1,8 +1,8 @@ # TF-Keras models # Placeholder: load unaliased py_library -load("@org_keras//tf_keras:tf_keras.bzl", "distribute_py_test") -load("@org_keras//tf_keras:tf_keras.bzl", "tf_py_test") +load("@org_keras//tf_keras:tf_keras.bzl", "tf_py_test") # buildifier: disable=same-origin-load +load("@org_keras//tf_keras:tf_keras.bzl", "distribute_py_test") # buildifier: disable=same-origin-load package( # copybara:uncomment default_applicable_licenses = ["//tf_keras:license"], diff --git a/tf_keras/optimizers/BUILD b/tf_keras/optimizers/BUILD index d294a820d..5c0d969ff 100644 --- a/tf_keras/optimizers/BUILD +++ b/tf_keras/optimizers/BUILD @@ -2,11 +2,9 @@ # Contains the TF-Keras Optimizer API. # Placeholder: load unaliased py_library -load("@org_keras//tf_keras:tf_keras.bzl", "cuda_py_test") - -# buildifier: disable=same-origin-load -load("@org_keras//tf_keras:tf_keras.bzl", "tf_py_test") -load("@org_keras//tf_keras:tf_keras.bzl", "distribute_py_test") +load("@org_keras//tf_keras:tf_keras.bzl", "cuda_py_test") # buildifier: disable=same-origin-load +load("@org_keras//tf_keras:tf_keras.bzl", "tf_py_test") # buildifier: disable=same-origin-load +load("@org_keras//tf_keras:tf_keras.bzl", "distribute_py_test") # buildifier: disable=same-origin-load package( # copybara:uncomment default_applicable_licenses = ["//tf_keras:license"], diff --git a/tf_keras/optimizers/legacy/BUILD b/tf_keras/optimizers/legacy/BUILD index 5825c70b0..af3dd69ef 100644 --- a/tf_keras/optimizers/legacy/BUILD +++ b/tf_keras/optimizers/legacy/BUILD @@ -1,8 +1,8 @@ # Description: # Contains the TF-Keras OptimizerV2 API (internal TensorFlow version). -# Placeholder: load unaliased py_library load("@org_keras//tf_keras:tf_keras.bzl", "cuda_py_test") +# Placeholder: load unaliased py_library package( # copybara:uncomment default_applicable_licenses = ["//tf_keras:license"], diff --git a/tf_keras/saving/BUILD b/tf_keras/saving/BUILD index b4303fd0e..912dcaf81 100644 --- a/tf_keras/saving/BUILD +++ b/tf_keras/saving/BUILD @@ -1,8 +1,8 @@ # Description: # Contains the TF-Keras save model API (internal TensorFlow version). -# Placeholder: load unaliased py_library load("@org_keras//tf_keras:tf_keras.bzl", "tf_py_test") +# Placeholder: load unaliased py_library package( # copybara:uncomment default_applicable_licenses = ["//tf_keras:license"], diff --git a/tf_keras/testing_infra/BUILD b/tf_keras/testing_infra/BUILD index 3a8d76342..012cbc03f 100644 --- a/tf_keras/testing_infra/BUILD +++ b/tf_keras/testing_infra/BUILD @@ -1,9 +1,9 @@ # Description: # Contains the TF-Keras testing infrastructure. +load("@org_keras//tf_keras:tf_keras.bzl", "tf_py_test") # Placeholder: load unaliased py_library # Placeholder: load unaliased py_test -load("@org_keras//tf_keras:tf_keras.bzl", "tf_py_test") package( # copybara:uncomment default_applicable_licenses = ["//tf_keras:license"], diff --git a/tf_keras/tests/BUILD b/tf_keras/tests/BUILD index 55e2e0734..321f46537 100644 --- a/tf_keras/tests/BUILD +++ b/tf_keras/tests/BUILD @@ -1,16 +1,12 @@ # Description: # Contains TF-Keras test utils and integration tests. +load("@org_keras//tf_keras:tf_keras.bzl", "cuda_py_test") # buildifier: disable=same-origin-load +load("@org_keras//tf_keras:tf_keras.bzl", "tf_py_test") # buildifier: disable=same-origin-load +load("@org_keras//tf_keras:tf_keras.bzl", "tpu_py_test") # buildifier: disable=same-origin-load # Placeholder: load unaliased py_library # Placeholder: load unaliased py_test -# buildifier: disable=same-origin-load -load("@org_keras//tf_keras:tf_keras.bzl", "cuda_py_test") - -# buildifier: disable=same-origin-load -load("@org_keras//tf_keras:tf_keras.bzl", "tf_py_test") -load("@org_keras//tf_keras:tf_keras.bzl", "tpu_py_test") - package( # copybara:uncomment default_applicable_licenses = ["//tf_keras:license"], default_visibility = [ diff --git a/tf_keras/utils/BUILD b/tf_keras/utils/BUILD index 27327e7f6..93fb50e4e 100644 --- a/tf_keras/utils/BUILD +++ b/tf_keras/utils/BUILD @@ -1,8 +1,8 @@ # Description: # Contains the TF-Keras Utilities (internal TensorFlow version). -# Placeholder: load unaliased py_library load("@org_keras//tf_keras:tf_keras.bzl", "tf_py_test") +# Placeholder: load unaliased py_library package( # copybara:uncomment default_applicable_licenses = ["//tf_keras:license"], diff --git a/tf_keras/utils/steps_per_execution_tuning.py b/tf_keras/utils/steps_per_execution_tuning.py index 3bebe4ecf..d106c2ba8 100644 --- a/tf_keras/utils/steps_per_execution_tuning.py +++ b/tf_keras/utils/steps_per_execution_tuning.py @@ -229,7 +229,7 @@ def _tune(self): if current_spe >= spe_limit: new_spe = current_spe - elif current_spe == 0: + elif current_spe <= 0: new_spe = self.init_spe self._steps_per_execution.assign(np.round(new_spe))