GitOrigin-RevId: 0f0dc001cf
release-1.5
@@ -181,11 +181,6 @@ def synchronized(func: Callable): | |||
return wrapper | |||
def _get_device_count_worker(queue, device_type): | |||
num = get_device_count(device_type) | |||
queue.put(num) | |||
def _check_device_initialized(device_type: str, rank: int): | |||
try: | |||
test = Tensor(1, device=(device_type + str(rank))) | |||
@@ -198,19 +193,6 @@ def _check_device_initialized(device_type: str, rank: int): | |||
raise RuntimeError(errmsg) | |||
def get_device_count_by_fork(device_type: str): | |||
""" | |||
Get device count in fork thread. | |||
See https://stackoverflow.com/questions/22950047/cuda-initialization-error-after-fork | |||
for more information. | |||
""" | |||
q = mp.Queue() | |||
p = mp.Process(target=_get_device_count_worker, args=(q, device_type)) | |||
p.start() | |||
p.join() | |||
return q.get() | |||
def bcast_list_(inps: list, group: Group = WORLD): | |||
""" | |||
Broadcast tensors between given group. | |||
@@ -13,9 +13,10 @@ import queue | |||
from .. import _exit | |||
from ..core._imperative_rt.core2 import full_sync | |||
from ..device import get_device_count | |||
from ..logger import get_logger | |||
from .group import _set_machine_ranks, group_barrier, init_process_group | |||
from .helper import _check_device_initialized, get_device_count_by_fork | |||
from .helper import _check_device_initialized | |||
from .server import Client, Server | |||
WARN_SUBPROCESS_EXIT_WITHOUT_RETURN = ( | |||
@@ -91,9 +92,7 @@ class launcher: | |||
backend="auto", | |||
): | |||
self.func = func | |||
self.n_gpus = ( | |||
n_gpus if n_gpus is not None else get_device_count_by_fork(device_type) | |||
) | |||
self.n_gpus = n_gpus if n_gpus is not None else get_device_count(device_type) | |||
self.world_size = world_size if world_size is not None else self.n_gpus | |||
self.rank_start = rank_start | |||
self.master_ip = master_ip | |||
@@ -1188,11 +1188,11 @@ def copy(inp, device=None): | |||
import numpy as np | |||
import platform | |||
from megengine import tensor | |||
from megengine.distributed.helper import get_device_count_by_fork | |||
from megengine.device import get_device_count | |||
import megengine.functional as F | |||
x = tensor([1, 2, 3], np.int32) | |||
if 1 == get_device_count_by_fork("gpu"): | |||
if 1 == get_device_count("gpu"): | |||
y = F.copy(x, "cpu1") | |||
print(y.numpy()) | |||
else: | |||
@@ -15,7 +15,7 @@ import megengine.functional | |||
import megengine.module | |||
from megengine import Parameter | |||
from megengine.core._imperative_rt.core2 import sync | |||
from megengine.distributed.helper import get_device_count_by_fork | |||
from megengine.device import get_device_count | |||
from megengine.experimental.autograd import ( | |||
disable_higher_order_directive, | |||
enable_higher_order_directive, | |||
@@ -25,7 +25,7 @@ from megengine.module import Linear, Module | |||
sys.path.append(os.path.join(os.path.dirname(__file__), "helpers")) | |||
_ngpu = get_device_count_by_fork("gpu") | |||
_ngpu = get_device_count("gpu") | |||
@pytest.fixture(autouse=True) | |||
@@ -16,7 +16,6 @@ import megengine.autodiff as ad | |||
import megengine.distributed as dist | |||
import megengine.optimizer as optimizer | |||
from megengine import Parameter, tensor | |||
from megengine.distributed.helper import get_device_count_by_fork | |||
from megengine.module import Module | |||
from megengine.optimizer import SGD | |||
@@ -18,7 +18,6 @@ import megengine.functional as F | |||
import megengine.module as M | |||
import megengine.optimizer as optim | |||
from megengine.autodiff import GradManager | |||
from megengine.distributed.helper import get_device_count_by_fork | |||
from megengine.jit import trace | |||
@@ -20,7 +20,6 @@ from megengine.core._imperative_rt import CompNode, TensorAttr, imperative | |||
from megengine.core._imperative_rt.core2 import TensorWeakRef, apply, sync | |||
from megengine.core.autodiff.grad import Grad | |||
from megengine.core.ops.builtin import Elemwise, Identity | |||
from megengine.distributed.helper import get_device_count_by_fork | |||
from megengine.functional.distributed import remote_recv, remote_send | |||
@@ -31,7 +31,7 @@ from megengine.core.tensor.dtype import ( | |||
quint4, | |||
quint8, | |||
) | |||
from megengine.distributed.helper import get_device_count_by_fork | |||
from megengine.device import get_device_count | |||
from megengine.tensor import Tensor | |||
@@ -184,8 +184,7 @@ def test_dtype_int4_ffi_handle(): | |||
@pytest.mark.skipif( | |||
get_device_count_by_fork("gpu") != 0, | |||
reason="TypeCvt to quint4 is not supported on GPU", | |||
get_device_count("gpu") != 0, reason="TypeCvt to quint4 is not supported on GPU", | |||
) | |||
def test_quint4_typecvt(): | |||
device = "xpux" | |||
@@ -17,11 +17,7 @@ import megengine as mge | |||
import megengine.distributed as dist | |||
from megengine.core.ops.builtin import CollectiveComm, ParamPackConcat, ParamPackSplit | |||
from megengine.device import get_default_device | |||
from megengine.distributed.helper import ( | |||
get_device_count_by_fork, | |||
param_pack_concat, | |||
param_pack_split, | |||
) | |||
from megengine.distributed.helper import param_pack_concat, param_pack_split | |||
def _assert_q_empty(q): | |||
@@ -22,8 +22,7 @@ from megengine import Parameter, Tensor, is_cuda_available, tensor | |||
from megengine.core._trace_option import use_symbolic_shape | |||
from megengine.core.autodiff.grad import Grad | |||
from megengine.core.tensor.utils import make_shape_tuple | |||
from megengine.distributed.helper import get_device_count_by_fork | |||
from megengine.jit import trace | |||
from megengine.device import get_device_count | |||
def test_where(): | |||
@@ -613,7 +612,7 @@ def test_nms(): | |||
@pytest.mark.skipif( | |||
get_device_count_by_fork("gpu") > 0, reason="cuda does not support nchw int8" | |||
get_device_count("gpu") > 0, reason="cuda does not support nchw int8" | |||
) | |||
def test_conv_bias(): | |||
inp_scale = 1.5 | |||
@@ -715,9 +714,7 @@ def test_conv_bias(): | |||
run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "relu") | |||
@pytest.mark.skipif( | |||
get_device_count_by_fork("gpu") > 0, reason="no int8 algorithm on cuda" | |||
) | |||
@pytest.mark.skipif(get_device_count("gpu") > 0, reason="no int8 algorithm on cuda") | |||
def test_batch_conv_bias(): | |||
inp_scale = 1.5 | |||
w_scale = 2.5 | |||
@@ -16,7 +16,6 @@ import megengine.distributed as dist | |||
from megengine import Parameter, tensor | |||
from megengine.core._imperative_rt.core2 import sync | |||
from megengine.device import get_default_device, set_default_device | |||
from megengine.distributed.helper import get_device_count_by_fork | |||
from megengine.functional.distributed import ( | |||
all_gather, | |||
all_reduce_max, | |||
@@ -18,7 +18,6 @@ from megengine import tensor | |||
from megengine.core._trace_option import use_symbolic_shape | |||
from megengine.core.tensor import megbrain_graph as G | |||
from megengine.core.tensor.utils import astensor1d | |||
from megengine.distributed.helper import get_device_count_by_fork | |||
from megengine.jit import trace | |||
from megengine.utils.network import Network, set_symbolic_shape | |||
from megengine.utils.network_node import VarNode | |||
@@ -16,7 +16,6 @@ import megengine as mge | |||
import megengine.distributed as dist | |||
from megengine import Tensor | |||
from megengine.core._trace_option import use_symbolic_shape | |||
from megengine.distributed.helper import get_device_count_by_fork | |||
from megengine.module import BatchNorm1d, BatchNorm2d, SyncBatchNorm | |||
_assert_allclose = functools.partial(np.testing.assert_allclose, atol=5e-6, rtol=5e-6) | |||
@@ -6,7 +6,7 @@ import pytest | |||
import megengine.utils.comp_graph_tools as cgtools | |||
from megengine import jit, tensor | |||
from megengine.distributed.helper import get_device_count_by_fork | |||
from megengine.device import get_device_count | |||
from megengine.functional import expand_dims | |||
from megengine.module import ( | |||
BatchMatMulActivation, | |||
@@ -101,9 +101,7 @@ def test_qat_conv(): | |||
np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy()) | |||
@pytest.mark.skipif( | |||
get_device_count_by_fork("gpu") > 0, reason="no int8 algorithm on cuda" | |||
) | |||
@pytest.mark.skipif(get_device_count("gpu") > 0, reason="no int8 algorithm on cuda") | |||
def test_qat_batchmatmul_activation(): | |||
batch = 4 | |||
in_features = 8 | |||
@@ -13,7 +13,7 @@ import pytest | |||
import megengine as mge | |||
import megengine.distributed as dist | |||
from megengine.distributed.helper import get_device_count_by_fork | |||
from megengine.device import get_device_count | |||
from megengine.quantization import QuantMode, create_qparams | |||
from megengine.quantization.observer import ( | |||
ExponentialMovingAverageObserver, | |||
@@ -78,7 +78,7 @@ def test_passive_observer(): | |||
@pytest.mark.require_ngpu(2) | |||
@pytest.mark.isolated_distributed | |||
def test_sync_min_max_observer(): | |||
word_size = get_device_count_by_fork("gpu") | |||
word_size = get_device_count("gpu") | |||
x = np.random.rand(3 * word_size, 3, 3, 3).astype("float32") | |||
np_min, np_max = x.min(), x.max() | |||
@@ -96,7 +96,7 @@ def test_sync_min_max_observer(): | |||
@pytest.mark.require_ngpu(2) | |||
@pytest.mark.isolated_distributed | |||
def test_sync_exponential_moving_average_observer(): | |||
word_size = get_device_count_by_fork("gpu") | |||
word_size = get_device_count("gpu") | |||
t = np.random.rand() | |||
x1 = np.random.rand(3 * word_size, 3, 3, 3).astype("float32") | |||
x2 = np.random.rand(3 * word_size, 3, 3, 3).astype("float32") | |||
@@ -12,7 +12,7 @@ import pytest | |||
import megengine as mge | |||
import megengine.functional as F | |||
from megengine.core.tensor import dtype | |||
from megengine.distributed.helper import get_device_count_by_fork | |||
from megengine.device import get_device_count | |||
from megengine.functional.elemwise import _elemwise_multi_type, _elwise | |||
from megengine.quantization import QuantMode, create_qparams | |||
@@ -68,7 +68,7 @@ def test_elemwise(kind): | |||
@pytest.mark.skipif( | |||
get_device_count_by_fork("gpu") > 0, reason="cuda does not support nchw int8" | |||
get_device_count("gpu") > 0, reason="cuda does not support nchw int8" | |||
) | |||
def test_conv_bias(): | |||
inp_scale = np.float32(np.random.rand() + 1) | |||
@@ -26,12 +26,12 @@ from megengine.core.ops.builtin import ( | |||
PoissonRNG, | |||
UniformRNG, | |||
) | |||
from megengine.distributed.helper import get_device_count_by_fork | |||
from megengine.device import get_device_count | |||
from megengine.random import RNG, seed, uniform | |||
@pytest.mark.skipif( | |||
get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2", | |||
get_device_count("xpu") <= 2, reason="xpu counts need > 2", | |||
) | |||
def test_gaussian_op(): | |||
shape = ( | |||
@@ -61,7 +61,7 @@ def test_gaussian_op(): | |||
@pytest.mark.skipif( | |||
get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2", | |||
get_device_count("xpu") <= 2, reason="xpu counts need > 2", | |||
) | |||
def test_uniform_op(): | |||
shape = ( | |||
@@ -89,7 +89,7 @@ def test_uniform_op(): | |||
@pytest.mark.skipif( | |||
get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2", | |||
get_device_count("xpu") <= 2, reason="xpu counts need > 2", | |||
) | |||
def test_gamma_op(): | |||
_shape, _scale = 2, 0.8 | |||
@@ -117,7 +117,7 @@ def test_gamma_op(): | |||
@pytest.mark.skipif( | |||
get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2", | |||
get_device_count("xpu") <= 2, reason="xpu counts need > 2", | |||
) | |||
def test_beta_op(): | |||
_alpha, _beta = 2, 0.8 | |||
@@ -148,7 +148,7 @@ def test_beta_op(): | |||
@pytest.mark.skipif( | |||
get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2", | |||
get_device_count("xpu") <= 2, reason="xpu counts need > 2", | |||
) | |||
def test_poisson_op(): | |||
lam = F.full([8, 9, 11, 12], value=2, dtype="float32") | |||
@@ -171,7 +171,7 @@ def test_poisson_op(): | |||
@pytest.mark.skipif( | |||
get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2", | |||
get_device_count("xpu") <= 2, reason="xpu counts need > 2", | |||
) | |||
def test_permutation_op(): | |||
n = 1000 | |||
@@ -205,7 +205,7 @@ def test_permutation_op(): | |||
@pytest.mark.skipif( | |||
get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1", | |||
get_device_count("xpu") <= 1, reason="xpu counts need > 1", | |||
) | |||
def test_UniformRNG(): | |||
m1 = RNG(seed=111, device="xpu0") | |||
@@ -233,7 +233,7 @@ def test_UniformRNG(): | |||
@pytest.mark.skipif( | |||
get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1", | |||
get_device_count("xpu") <= 1, reason="xpu counts need > 1", | |||
) | |||
def test_NormalRNG(): | |||
m1 = RNG(seed=111, device="xpu0") | |||
@@ -262,7 +262,7 @@ def test_NormalRNG(): | |||
@pytest.mark.skipif( | |||
get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1", | |||
get_device_count("xpu") <= 1, reason="xpu counts need > 1", | |||
) | |||
def test_GammaRNG(): | |||
m1 = RNG(seed=111, device="xpu0") | |||
@@ -295,7 +295,7 @@ def test_GammaRNG(): | |||
@pytest.mark.skipif( | |||
get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1", | |||
get_device_count("xpu") <= 1, reason="xpu counts need > 1", | |||
) | |||
def test_BetaRNG(): | |||
m1 = RNG(seed=111, device="xpu0") | |||
@@ -330,7 +330,7 @@ def test_BetaRNG(): | |||
@pytest.mark.skipif( | |||
get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1", | |||
get_device_count("xpu") <= 1, reason="xpu counts need > 1", | |||
) | |||
def test_PoissonRNG(): | |||
m1 = RNG(seed=111, device="xpu0") | |||
@@ -359,7 +359,7 @@ def test_PoissonRNG(): | |||
@pytest.mark.skipif( | |||
get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1", | |||
get_device_count("xpu") <= 1, reason="xpu counts need > 1", | |||
) | |||
def test_PermutationRNG(): | |||
m1 = RNG(seed=111, device="xpu0") | |||
@@ -13,8 +13,7 @@ import megengine.random as rand | |||
from megengine.core._imperative_rt.core2 import apply | |||
from megengine.core._wrap import Device | |||
from megengine.core.ops import builtin | |||
from megengine.device import is_cuda_available | |||
from megengine.distributed.helper import get_device_count_by_fork | |||
from megengine.device import get_device_count, is_cuda_available | |||
from megengine.functional.external import tensorrt_runtime_opr | |||
from megengine.jit.tracing import trace | |||
from megengine.tensor import Tensor | |||
@@ -273,7 +272,7 @@ def test_deformable_ps_roi_pooling(): | |||
@pytest.mark.skipif( | |||
get_device_count_by_fork("gpu") > 0, | |||
get_device_count("gpu") > 0, | |||
reason="does not support int8 when gpu compute capability less than 6.1", | |||
) | |||
def test_convbias(): | |||
@@ -27,8 +27,14 @@ using namespace mgb; | |||
#include <thread> | |||
#include <cuda.h> | |||
#include <cuda_runtime.h> | |||
#ifdef __unix__ | |||
#include <unistd.h> | |||
#include <sys/wait.h> | |||
#endif | |||
using CudaCompNodeImpl = CudaCompNode::CompNodeImpl; | |||
namespace { | |||
@@ -700,19 +706,90 @@ void CudaCompNode::EventImpl::do_device_wait_by(Impl* cn_impl) { | |||
/* ===================== CudaCompNode static methods ===================== */ | |||
namespace { | |||
#ifndef __unix__ | |||
CUresult get_device_count_forksafe(int* pcnt) { | |||
cuInit(0); | |||
return cuDeviceGetCount(pcnt); | |||
} | |||
#else | |||
struct RAIICloseFD : NonCopyableObj { | |||
int m_fd = -1; | |||
RAIICloseFD(int fd) : m_fd(fd) {} | |||
~RAIICloseFD() {close();} | |||
void close() { | |||
if (m_fd != -1) { | |||
::close(m_fd); | |||
m_fd = -1; | |||
} | |||
} | |||
}; | |||
// an implementation that does not call cuInit | |||
CUresult get_device_count_forksafe(int* pcnt) { | |||
auto err = cuDeviceGetCount(pcnt); | |||
if (err != CUDA_ERROR_NOT_INITIALIZED) return err; | |||
// cuInit not called, call it in child process | |||
int fd[2]; | |||
mgb_assert(pipe(fd) == 0, "pipe() failed"); | |||
int fdr = fd[0], fdw = fd[1]; | |||
RAIICloseFD fdr_guard(fdr); | |||
RAIICloseFD fdw_guard(fdw); | |||
auto cpid = fork(); | |||
mgb_assert(cpid != -1, "fork() failed"); | |||
if (cpid == 0) { | |||
fdr_guard.close(); | |||
do { | |||
err = cuInit(0); | |||
if (err != CUDA_SUCCESS) break; | |||
err = cuDeviceGetCount(pcnt); | |||
} while (0); | |||
auto sz = write(fdw, &err, sizeof(err)); | |||
if (sz == sizeof(err) && err == CUDA_SUCCESS) { | |||
sz = write(fdw, pcnt, sizeof(*pcnt)); | |||
} | |||
fdw_guard.close(); | |||
std::quick_exit(0); | |||
} | |||
fdw_guard.close(); | |||
auto sz = read(fdr, &err, sizeof(err)); | |||
mgb_assert(sz == sizeof(err), "failed to read error code from child"); | |||
if (err == CUDA_SUCCESS) { | |||
sz = read(fdr, pcnt, sizeof(*pcnt)); | |||
mgb_assert(sz == sizeof(*pcnt), "failed to read device count from child"); | |||
return err; | |||
} | |||
// try again, maybe another thread called cuInit while we fork | |||
auto err2 = cuDeviceGetCount(pcnt); | |||
if (err2 == CUDA_SUCCESS) return err2; | |||
if (err2 == CUDA_ERROR_NOT_INITIALIZED) return err; | |||
return err2; | |||
} | |||
#endif | |||
const char* cu_get_error_string(CUresult err) { | |||
const char* ret = nullptr; | |||
cuGetErrorString(err, &ret); | |||
if (!ret) ret = "unknown cuda error"; | |||
return ret; | |||
} | |||
} // namespace | |||
bool CudaCompNode::available() { | |||
static int result = -1; | |||
static Spinlock mtx; | |||
MGB_LOCK_GUARD(mtx); | |||
if (result == -1) { | |||
int ndev = -1; | |||
auto err = cudaGetDeviceCount(&ndev); | |||
result = err == cudaSuccess && ndev > 0; | |||
auto err = get_device_count_forksafe(&ndev); | |||
result = err == CUDA_SUCCESS && ndev > 0; | |||
if (!result) { | |||
mgb_log_warn("cuda unavailable: %s(%d) ndev=%d", | |||
cudaGetErrorString(err), static_cast<int>(err), ndev); | |||
cu_get_error_string(err), static_cast<int>(err), ndev); | |||
} | |||
if (err == cudaErrorInitializationError) { | |||
if (err == CUDA_ERROR_NOT_INITIALIZED) { | |||
mgb_throw(std::runtime_error, "cuda initialization error."); | |||
} | |||
} | |||
@@ -857,11 +934,11 @@ size_t CudaCompNode::get_device_count(bool warn) { | |||
static Spinlock mtx; | |||
MGB_LOCK_GUARD(mtx); | |||
if (cnt == -1) { | |||
auto err = cudaGetDeviceCount(&cnt); | |||
if (err != cudaSuccess) { | |||
auto err = get_device_count_forksafe(&cnt); | |||
if (err != CUDA_SUCCESS) { | |||
if (warn) | |||
mgb_log_error("cudaGetDeviceCount failed: %s (err %d)", | |||
cudaGetErrorString(err), int(err)); | |||
cu_get_error_string(err), int(err)); | |||
cnt = 0; | |||
} | |||
mgb_assert(cnt >= 0); | |||