Add more test coverage for the model integration test on all hardware types.

PiperOrigin-RevId: 551245005
qlzh727 authored and tensorflower-gardener committed Jul 26, 2023
1 parent 8d5e9b2 commit bdfb8aa
Showing 2 changed files with 57 additions and 89 deletions.
41 changes: 25 additions & 16 deletions keras/dtensor/BUILD
@@ -117,22 +117,31 @@ tf_py_test(
     ],
 )
 
-tf_py_test(
-    name = "mnist_model_test",
-    srcs = ["mnist_model_test.py"],
-    shard_count = 2,
-    tags = [
-        "requires-net:external",
-    ],
-    deps = [
-        ":integration_test_utils",
-        ":test_util",
-        "//:expect_numpy_installed",
-        "//:expect_tensorflow_installed",
-        "//keras/optimizers",
-        "//keras/utils:tf_utils",
-    ],
-)
+# copybara:uncomment_begin(google-only)
+# dtensor_test(
+#     name = "mnist_model_test",
+#     srcs = ["mnist_model_test.py"],
+#     env = {
+#         "CUDA_MODULE_LOADING": "LAZY",
+#         "TF_GPU_ALLOCATOR": "cuda_malloc_async",
+#     },
+#     tags = [
+#         "no_oss",
+#         "requires-net:external",
+#     ],
+#     deps = [
+#         ":dtensor",
+#         ":integration_test_utils",
+#         ":layout_map",
+#         ":test_util",
+#         "//keras:backend",
+#         "//keras/optimizers",
+#         "//keras/utils:tf_utils",
+#         "//:expect_numpy_installed",
+#         "//:expect_tensorflow_installed",
+#     ],
+# )
+# copybara:uncomment_end
 
 tf_py_test(
     name = "optimizers_test",
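For context: the copybara markers keep the new dtensor_test rule commented out in the OSS tree; it is uncommented in the Google-internal copy, where the rule runs this test on each hardware type. Below is a minimal sketch of reproducing the rule's GPU environment outside Bazel; only the two variable values come from the env block above, the rest is illustrative.

# Sketch only: both variables must be set before TensorFlow initializes GPUs.
import os

os.environ["CUDA_MODULE_LOADING"] = "LAZY"  # defer loading of CUDA modules
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"  # use cudaMallocAsync

import tensorflow.compat.v2 as tf  # import only after the env vars are set

print(tf.config.list_physical_devices("GPU"))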
105 changes: 32 additions & 73 deletions keras/dtensor/mnist_model_test.py
@@ -14,76 +14,64 @@
 # ==============================================================================
 """E2E Tests for mnist_model."""
 
+import numpy as np
 import tensorflow.compat.v2 as tf
+from tensorflow.compat.v2.experimental import dtensor
 
 from keras import backend
-from keras.dtensor import dtensor_api as dtensor
 from keras.dtensor import integration_test_utils
+from keras.dtensor import layout_map as layout_map_lib
 from keras.dtensor import test_util
 from keras.optimizers import adam
 from keras.utils import tf_utils
 
 
 class MnistTest(test_util.DTensorBaseTest):
-    def test_mnist_training_cpu(self):
-        devices = tf.config.list_physical_devices("CPU")
-        tf.config.set_logical_device_configuration(
-            devices[0],
-            [
-                tf.config.LogicalDeviceConfiguration(),
-            ]
-            * 8,
-        )
-
-        mesh = dtensor.create_mesh(
-            devices=["CPU:%d" % i for i in range(8)], mesh_dims=[("batch", 8)]
-        )
+    def setUp(self):
+        super().setUp()
+        backend.enable_tf_random_generator()
+        # Needed by keras initializers.
+        tf_utils.set_random_seed(1337)
+        global_ids = test_util.create_device_ids_array((2,))
+        local_device_ids = np.ravel(global_ids).tolist()
+        mesh_dict = {
+            device: tf.experimental.dtensor.Mesh(
+                ["batch"],
+                global_ids,
+                local_device_ids,
+                test_util.create_device_list((2,), device),
+            )
+            for device in ("CPU", "GPU", "TPU")
+        }
+        self.mesh = self.configTestMesh(mesh_dict)
 
-        model = integration_test_utils.get_model_with_layout_map(
-            integration_test_utils.get_all_replicated_layout_map(mesh)
-        )
+    def test_mnist_training(self):
+        layout_map = layout_map_lib.LayoutMap(self.mesh)
+        with layout_map.scope():
+            model = integration_test_utils.get_model()
 
-        optimizer = adam.Adam(learning_rate=0.001, mesh=mesh)
+        optimizer = adam.Adam(learning_rate=0.001, mesh=self.mesh)
         optimizer.build(model.trainable_variables)
 
         train_losses = integration_test_utils.train_mnist_model_batch_sharded(
             model,
             optimizer,
-            mesh,
+            self.mesh,
             num_epochs=3,
-            steps_per_epoch=100,
+            steps_per_epoch=20,
             global_batch_size=64,
         )
         # Make sure the losses are decreasing
         self.assertEqual(train_losses, sorted(train_losses, reverse=True))
 
     def test_model_fit(self):
-        devices = tf.config.list_physical_devices("CPU")
-        tf.config.set_logical_device_configuration(
-            devices[0],
-            [
-                tf.config.LogicalDeviceConfiguration(),
-            ]
-            * 8,
-        )
+        if self.mesh.device_type() == "GPU":
+            self.skipTest("TODO(b/292596476)")
 
-        mesh = dtensor.create_mesh(
-            devices=["CPU:%d" % i for i in range(8)], mesh_dims=[("batch", 8)]
-        )
-
-        backend.enable_tf_random_generator()
-        # Needed by keras initializers.
-        tf_utils.set_random_seed(1337)
-
-        model = integration_test_utils.get_model_with_layout_map(
-            integration_test_utils.get_all_replicated_layout_map(mesh)
-        )
+        layout_map = layout_map_lib.LayoutMap(self.mesh)
+        with layout_map.scope():
+            model = integration_test_utils.get_model()
 
-        optimizer = adam.Adam(learning_rate=0.001, mesh=mesh)
+        optimizer = adam.Adam(learning_rate=0.001, mesh=self.mesh)
         optimizer.build(model.trainable_variables)
 
         global_batch_size = 64
@@ -100,7 +88,7 @@ def distribute_ds(dataset):
             def _create_batch_layout(tensor_spec):
                 rank = len(tensor_spec.shape) + 1
                 return dtensor.Layout.batch_sharded(
-                    mesh, batch_dim="batch", rank=rank
+                    self.mesh, batch_dim="batch", rank=rank
                 )
 
             layouts = tf.nest.map_structure(
@@ -109,7 +97,7 @@ def _create_batch_layout(tensor_spec):
 
             return dtensor.DTensorDataset(
                 dataset=dataset,
-                mesh=mesh,
+                mesh=self.mesh,
                 layouts=layouts,
                 global_batch_size=global_batch_size,
                 dataset_already_batched=False,
@@ -123,35 +111,6 @@ def _create_batch_layout(tensor_spec):
         model.fit(train_ds, steps_per_epoch=10)
         model.evaluate(eval_ds, steps=10)
 
-    def DISABLED_test_mnist_training_tpu(self):
-        # TODO(scottzhu): Enable TPU test once the dtensor_test rule is migrated
-        # out of learning/brain
-        dtensor.initialize_accelerator_system()
-        total_tpu_device_count = dtensor.num_global_devices("TPU")
-        mesh_shape = [total_tpu_device_count]
-        mesh = dtensor.create_tpu_mesh(["batch"], mesh_shape, "tpu_mesh")
-
-        # Needed by keras initializers.
-        tf_utils.set_random_seed(1337)
-
-        model = integration_test_utils.get_model_with_layout_map(
-            integration_test_utils.get_all_replicated_layout_map(mesh)
-        )
-
-        optimizer = adam.Adam(learning_rate=0.001, mesh=mesh)
-        optimizer.build(model.trainable_variables)
-
-        train_losses = integration_test_utils.train_mnist_model_batch_sharded(
-            model,
-            optimizer,
-            mesh,
-            num_epochs=3,
-            steps_per_epoch=100,
-            global_batch_size=64,
-        )
-        # Make sure the losses are decreasing
-        self.assertEqual(train_losses, sorted(train_losses, reverse=True))
-
 
 if __name__ == "__main__":
     tf.test.main()
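For readers following the new setUp: it builds one two-device mesh per hardware type (CPU, GPU, TPU) and configTestMesh picks the one matching the host. A minimal standalone sketch of the same batch-mesh pattern on CPU, assuming a 2-way logical device split; the deleted test body used the same APIs with an 8-way split.

import tensorflow.compat.v2 as tf
from tensorflow.compat.v2.experimental import dtensor

# Split the single physical CPU into two logical devices.
cpus = tf.config.list_physical_devices("CPU")
tf.config.set_logical_device_configuration(
    cpus[0], [tf.config.LogicalDeviceConfiguration()] * 2
)

# One-dimensional "batch" mesh across the two logical CPUs.
mesh = dtensor.create_mesh(
    devices=["CPU:%d" % i for i in range(2)], mesh_dims=[("batch", 2)]
)
print(mesh.device_type(), mesh.num_local_devices())  # CPU 2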

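Similarly, the distribute_ds helper in test_model_fit shards each batch across the mesh's batch dimension. A hedged sketch of it as a free function, assuming mesh comes from the previous snippet and the input dataset is unbatched; the signature here is illustrative, the test keeps it as a closure:

import tensorflow.compat.v2 as tf
from tensorflow.compat.v2.experimental import dtensor


def distribute_ds(dataset, mesh, global_batch_size):
    """Wraps dataset so each element is batch-sharded over mesh."""

    def _create_batch_layout(tensor_spec):
        # Batching adds one leading dimension, hence rank + 1.
        rank = len(tensor_spec.shape) + 1
        return dtensor.Layout.batch_sharded(mesh, batch_dim="batch", rank=rank)

    layouts = tf.nest.map_structure(_create_batch_layout, dataset.element_spec)
    return dtensor.DTensorDataset(
        dataset=dataset,
        mesh=mesh,
        layouts=layouts,
        global_batch_size=global_batch_size,
        dataset_already_batched=False,
    )


# Illustrative usage with a toy unbatched dataset:
ds = tf.data.Dataset.from_tensor_slices(
    (tf.zeros([128, 28, 28, 1]), tf.zeros([128], dtype=tf.int64))
)
train_ds = distribute_ds(ds, mesh, global_batch_size=64)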