|
@@ -186,9 +186,9 @@ def _get_device_count_worker(queue, device_type): |
|
|
queue.put(num) |
|
|
queue.put(num) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _check_device_initialized(device_type: str): |
|
|
|
|
|
|
|
|
def _check_device_initialized(device_type: str, rank: int): |
|
|
try: |
|
|
try: |
|
|
test = Tensor(1, device=device_type) |
|
|
|
|
|
|
|
|
test = Tensor(1, device=(device_type + str(rank))) |
|
|
inited = False |
|
|
inited = False |
|
|
del test |
|
|
del test |
|
|
except: |
|
|
except: |
|
|