|
|
@@ -199,13 +199,14 @@ def test_param_pack_concat(): |
|
|
|
|
|
|
|
@pytest.mark.require_ngpu(2) |
|
|
|
@pytest.mark.parametrize("early_return", [False, True], ids=["common", "early_return"]) |
|
|
|
@pytest.mark.parametrize("output_size", [10, 10000], ids=["small_size", "large_size"]) |
|
|
|
@pytest.mark.isolated_distributed |
|
|
|
def test_collect_results(early_return): |
|
|
|
def test_collect_results(early_return, output_size): |
|
|
|
@dist.launcher |
|
|
|
def worker(): |
|
|
|
if early_return: |
|
|
|
exit(0) |
|
|
|
return (dist.get_rank(), dist.get_world_size()) |
|
|
|
return [dist.get_rank()] * output_size |
|
|
|
|
|
|
|
results = worker() |
|
|
|
world_size = len(results) |
|
|
@@ -213,6 +214,6 @@ def test_collect_results(early_return): |
|
|
|
expects = ( |
|
|
|
[None] * world_size |
|
|
|
if early_return |
|
|
|
else [(dev, world_size) for dev in range(world_size)] |
|
|
|
else [[dev] * output_size for dev in range(world_size)] |
|
|
|
) |
|
|
|
assert results == expects |