Skip to content

Commit

Permalink
Use file store for tests (#6632)
Browse files Browse the repository at this point in the history
This PR changes the `init_method` for tests to `FileStore` for
robustness.
  • Loading branch information
tohtana authored Oct 17, 2024
1 parent a36db9c commit c9fc34a
Showing 1 changed file with 31 additions and 19 deletions.
50 changes: 31 additions & 19 deletions tests/unit/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,16 +147,13 @@ class DistributedExec(ABC):
def run(self):
...

def __call__(self, request=None):
def __call__(self, request):
self._fixture_kwargs = self._get_fixture_kwargs(request, self.run)
world_size = self.world_size
if self.requires_cuda_env and not get_accelerator().is_available():
pytest.skip("only supported in accelerator environments.")

if isinstance(world_size, int):
world_size = [world_size]
for procs in world_size:
self._launch_procs(procs)
self._launch_with_file_store(request, world_size)

def _get_fixture_kwargs(self, request, func):
if not request:
Expand All @@ -172,7 +169,7 @@ def _get_fixture_kwargs(self, request, func):
pass # test methods can have kwargs that are not fixtures
return fixture_kwargs

def _launch_daemonic_procs(self, num_procs):
def _launch_daemonic_procs(self, num_procs, init_method):
# Create process pool or use cached one
master_port = None

Expand All @@ -198,7 +195,7 @@ def _launch_daemonic_procs(self, num_procs):
master_port = get_master_port()

# Run the test
args = [(local_rank, num_procs, master_port) for local_rank in range(num_procs)]
args = [(local_rank, num_procs, master_port, init_method) for local_rank in range(num_procs)]
skip_msgs_async = pool.starmap_async(self._dist_run, args)

try:
Expand All @@ -218,7 +215,7 @@ def _launch_daemonic_procs(self, num_procs):
assert len(set(skip_msgs)) == 1, "Multiple different skip messages received"
pytest.skip(skip_msgs[0])

def _launch_non_daemonic_procs(self, num_procs):
def _launch_non_daemonic_procs(self, num_procs, init_method):
assert not self.reuse_dist_env, "Cannot reuse distributed environment with non-daemonic processes"

master_port = get_master_port()
Expand All @@ -227,7 +224,7 @@ def _launch_non_daemonic_procs(self, num_procs):
prev_start_method = mp.get_start_method()
mp.set_start_method('spawn', force=True)
for local_rank in range(num_procs):
p = mp.Process(target=self._dist_run, args=(local_rank, num_procs, master_port, skip_msg))
p = mp.Process(target=self._dist_run, args=(local_rank, num_procs, master_port, init_method, skip_msg))
p.start()
processes.append(p)
mp.set_start_method(prev_start_method, force=True)
Expand Down Expand Up @@ -269,7 +266,7 @@ def _launch_non_daemonic_procs(self, num_procs):
# add a check here to assert all exit messages are equal
pytest.skip(skip_msg.get())

def _launch_procs(self, num_procs):
def _launch_procs(self, num_procs, init_method):
# Verify we have enough accelerator devices to run this test
if get_accelerator().is_available() and get_accelerator().device_count() < num_procs:
pytest.skip(
Expand All @@ -284,11 +281,11 @@ def _launch_procs(self, num_procs):
mp.set_start_method('forkserver', force=True)

if self.non_daemonic_procs:
self._launch_non_daemonic_procs(num_procs)
self._launch_non_daemonic_procs(num_procs, init_method)
else:
self._launch_daemonic_procs(num_procs)
self._launch_daemonic_procs(num_procs, init_method)

def _dist_run(self, local_rank, num_procs, master_port, skip_msg=""):
def _dist_run(self, local_rank, num_procs, master_port, init_method, skip_msg=""):
if not dist.is_initialized():
""" Initialize deepspeed.comm and execute the user function. """
if self.set_dist_env:
Expand All @@ -312,7 +309,10 @@ def _dist_run(self, local_rank, num_procs, master_port, skip_msg=""):
get_accelerator().set_device(local_rank)

if self.init_distributed:
deepspeed.init_distributed(dist_backend=self.backend)
deepspeed.init_distributed(dist_backend=self.backend,
init_method=init_method,
rank=local_rank,
world_size=num_procs)
dist.barrier()

try:
Expand All @@ -328,6 +328,22 @@ def _dist_run(self, local_rank, num_procs, master_port, skip_msg=""):

return skip_msg

def _launch_with_file_store(self, request, world_size):
    """Launch the test once per world size using a file-based init method.

    A path named ``dist_file_store`` inside the directory provided by the
    ``tmpdir`` fixture is turned into a ``file://`` URL and handed to
    ``_launch_procs`` as the ``init_method`` for every launch.

    Args:
        request: pytest fixture request; only ``getfixturevalue("tmpdir")``
            is used here.
        world_size: a single process count, or an iterable of process
            counts that are launched one after another.
    """
    tmpdir = request.getfixturevalue("tmpdir")
    dist_file_store = tmpdir.join("dist_file_store")
    # A leftover store file would be picked up by the new rendezvous.
    assert not os.path.exists(dist_file_store)
    init_method = f"file://{dist_file_store}"

    if isinstance(world_size, int):
        world_size = [world_size]
    for procs in world_size:
        try:
            self._launch_procs(procs, init_method)
        finally:
            # Clear the store file so the next launch starts from scratch.
            # Remove directly instead of exists()-then-remove(): the file may
            # never have been created, or may vanish between the check and
            # the removal (the file-based rendezvous may clean up its own
            # file), which would otherwise raise from this cleanup path.
            try:
                os.remove(dist_file_store)
            except FileNotFoundError:
                pass
        # Short pause between consecutive launches (presumably to let the
        # previous processes release their resources -- confirm).
        time.sleep(0.5)

def _dist_destroy(self):
if (dist is not None) and dist.is_initialized():
dist.barrier()
Expand Down Expand Up @@ -473,11 +489,7 @@ def __call__(self, request):
else:
world_size = self._fixture_kwargs.get("world_size", self.world_size)

if isinstance(world_size, int):
world_size = [world_size]
for procs in world_size:
self._launch_procs(procs)
time.sleep(0.5)
self._launch_with_file_store(request, world_size)

def _get_current_test_func(self, request):
# DistributedTest subclasses may have multiple test methods
Expand Down

0 comments on commit c9fc34a

Please sign in to comment.