diff --git a/imperative/python/megengine/distributed/helper.py b/imperative/python/megengine/distributed/helper.py index 5ad1b50f..f2d83ba0 100644 --- a/imperative/python/megengine/distributed/helper.py +++ b/imperative/python/megengine/distributed/helper.py @@ -186,9 +186,9 @@ def _get_device_count_worker(queue, device_type): queue.put(num) -def _check_device_initialized(device_type: str): +def _check_device_initialized(device_type: str, rank: int): try: - test = Tensor(1, device=device_type) + test = Tensor(1, device=(device_type + str(rank))) inited = False del test except: diff --git a/imperative/python/megengine/distributed/launcher.py b/imperative/python/megengine/distributed/launcher.py index b043705c..3e6d2b18 100644 --- a/imperative/python/megengine/distributed/launcher.py +++ b/imperative/python/megengine/distributed/launcher.py @@ -39,7 +39,7 @@ def _run_wrapped( machine_ranks: list, ): """Init distributed process group and run wrapped function.""" - _check_device_initialized(device_type) + _check_device_initialized(device_type, dev) init_process_group( master_ip=master_ip, port=port,