From 21c25fd38023a3783950c5577383ffe51a62f650 Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Thu, 27 Jul 2023 12:24:01 -0700 Subject: [PATCH] Update the optimizer test wrt the new DVariable update. PiperOrigin-RevId: 551607364 --- keras/dtensor/optimizers_test.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/keras/dtensor/optimizers_test.py b/keras/dtensor/optimizers_test.py index 80f74464aac..356d2d2965e 100644 --- a/keras/dtensor/optimizers_test.py +++ b/keras/dtensor/optimizers_test.py @@ -89,7 +89,11 @@ def test_aggregate_gradients_noop(self): optimizer = adam.Adam(mesh=self.mesh) variable_init_value = tf.ones(shape=(), dtype=tf.float32) - model_variable = dtensor.DVariable(variable_init_value, trainable=True) + model_variable = dtensor.DVariable( + variable_init_value, + trainable=True, + layout=dtensor.Layout.replicated(self.mesh, rank=0), + ) grads = tf.ones_like(variable_init_value) grad_and_var = zip([grads], [model_variable])