Skip to content

Commit

Permalink
Update the optimizer test wrt the new DVariable update.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 551607364
  • Loading branch information
qlzh727 authored and tensorflower-gardener committed Jul 27, 2023
1 parent ab566fd commit 21c25fd
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion keras/dtensor/optimizers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,11 @@ def test_aggregate_gradients_noop(self):
optimizer = adam.Adam(mesh=self.mesh)

variable_init_value = tf.ones(shape=(), dtype=tf.float32)
model_variable = dtensor.DVariable(variable_init_value, trainable=True)
model_variable = dtensor.DVariable(
variable_init_value,
trainable=True,
layout=dtensor.Layout.replicated(self.mesh, rank=0),
)
grads = tf.ones_like(variable_init_value)

grad_and_var = zip([grads], [model_variable])
Expand Down

0 comments on commit 21c25fd

Please sign in to comment.