GitOrigin-RevId: 9abcc956ef
tags/v1.4.0-rc1
@@ -114,6 +114,7 @@ class launcher: | |||||
procs[dev].terminate() | procs[dev].terminate() | ||||
devs.clear() | devs.clear() | ||||
result_count = 0 | |||||
while len(devs) > 0: | while len(devs) > 0: | ||||
left = [] | left = [] | ||||
# check all processes in one second | # check all processes in one second | ||||
@@ -129,11 +130,21 @@ class launcher: | |||||
), "subprocess {} exit with code {}".format(dev + self.rank_start, code) | ), "subprocess {} exit with code {}".format(dev + self.rank_start, code) | ||||
if code == None: | if code == None: | ||||
left.append(dev) | left.append(dev) | ||||
elif queue.empty(): | |||||
get_logger().warning(WARN_SUBPROCESS_EXIT_WITHOUT_RETURN) | |||||
else: | |||||
# DO NOT delete it, multiprocess.Queue has small buffer | |||||
# fetch data early to avoid dead lock | |||||
if not queue.empty(): | |||||
result_count += 1 | |||||
dev, ret = queue.get_nowait() | dev, ret = queue.get_nowait() | ||||
results[dev] = ret | results[dev] = ret | ||||
devs = left | devs = left | ||||
while not queue.empty(): | |||||
result_count += 1 | |||||
dev, ret = queue.get_nowait() | |||||
results[dev] = ret | |||||
if result_count < self.n_gpus: | |||||
get_logger().warning(WARN_SUBPROCESS_EXIT_WITHOUT_RETURN) | |||||
return results | return results |
@@ -199,13 +199,14 @@ def test_param_pack_concat(): | |||||
@pytest.mark.require_ngpu(2) | @pytest.mark.require_ngpu(2) | ||||
@pytest.mark.parametrize("early_return", [False, True], ids=["common", "early_return"]) | @pytest.mark.parametrize("early_return", [False, True], ids=["common", "early_return"]) | ||||
@pytest.mark.parametrize("output_size", [10, 10000], ids=["small_size", "large_size"]) | |||||
@pytest.mark.isolated_distributed | @pytest.mark.isolated_distributed | ||||
def test_collect_results(early_return): | |||||
def test_collect_results(early_return, output_size): | |||||
@dist.launcher | @dist.launcher | ||||
def worker(): | def worker(): | ||||
if early_return: | if early_return: | ||||
exit(0) | exit(0) | ||||
return (dist.get_rank(), dist.get_world_size()) | |||||
return [dist.get_rank()] * output_size | |||||
results = worker() | results = worker() | ||||
world_size = len(results) | world_size = len(results) | ||||
@@ -213,6 +214,6 @@ def test_collect_results(early_return): | |||||
expects = ( | expects = ( | ||||
[None] * world_size | [None] * world_size | ||||
if early_return | if early_return | ||||
else [(dev, world_size) for dev in range(world_size)] | |||||
else [[dev] * output_size for dev in range(world_size)] | |||||
) | ) | ||||
assert results == expects | assert results == expects |