From 933dd9a4971fbcd06c632b003e193dd6a87fdeff Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 29 Jun 2021 16:37:01 +0800 Subject: [PATCH] feat(mge/distributed): add cuda env check before forked thread style(core/comp_node): reformat code GitOrigin-RevId: 372452a8eb9e84a2e82d466074f80f78d70531e8 --- imperative/python/megengine/distributed/helper.py | 12 + .../python/megengine/distributed/launcher.py | 3 +- .../unit/functional/test_functional_distributed.py | 13 + src/core/impl/comp_node/cuda/comp_node.cpp | 481 ++++++++++----------- 4 files changed, 256 insertions(+), 253 deletions(-) diff --git a/imperative/python/megengine/distributed/helper.py b/imperative/python/megengine/distributed/helper.py index c0743958..c91c7e01 100644 --- a/imperative/python/megengine/distributed/helper.py +++ b/imperative/python/megengine/distributed/helper.py @@ -165,6 +165,18 @@ def _get_device_count_worker(queue, device_type): queue.put(num) +def _check_device_initialized(device_type: str): + try: + test = Tensor(1, device=device_type) + inited = False + del test + except: + inited = True + errmsg = "The cuda env is set before the forked thread starts. Please do not use any cuda function or variable before forking." + if inited: + raise RuntimeError(errmsg) + + def get_device_count_by_fork(device_type: str): """ Get device count in fork thread. diff --git a/imperative/python/megengine/distributed/launcher.py b/imperative/python/megengine/distributed/launcher.py index 94a6b85a..2b78be40 100644 --- a/imperative/python/megengine/distributed/launcher.py +++ b/imperative/python/megengine/distributed/launcher.py @@ -15,7 +15,7 @@ from .. import _exit from ..core._imperative_rt.core2 import full_sync from ..logger import get_logger from .group import group_barrier, init_process_group -from .helper import get_device_count_by_fork +from .helper import _check_device_initialized, get_device_count_by_fork from .server import Client, Server WARN_SUBPROCESS_EXIT_WITHOUT_RETURN = ( @@ -37,6 +37,7 @@ def _run_wrapped( queue: mp.Queue, ): """Init distributed process group and run wrapped function.""" + _check_device_initialized(device_type) init_process_group( master_ip=master_ip, port=port, diff --git a/imperative/python/test/unit/functional/test_functional_distributed.py b/imperative/python/test/unit/functional/test_functional_distributed.py index d982ac62..0eaaaa05 100644 --- a/imperative/python/test/unit/functional/test_functional_distributed.py +++ b/imperative/python/test/unit/functional/test_functional_distributed.py @@ -246,3 +246,16 @@ def test_io_remote(shape): val = np.random.random_sample(shape).astype("float32") worker(val, shape) + + +@pytest.mark.require_ngpu(2) +def test_cuda_init_before_fork(): + a = mge.tensor(1, device="gpu0") + + @dist.launcher(n_gpus=2) + def worker(): + a += 1 + b = mge.tensor(2) + + with pytest.raises(AssertionError): + worker() diff --git a/src/core/impl/comp_node/cuda/comp_node.cpp b/src/core/impl/comp_node/cuda/comp_node.cpp index 64279cc4..a27c01cf 100644 --- a/src/core/impl/comp_node/cuda/comp_node.cpp +++ b/src/core/impl/comp_node/cuda/comp_node.cpp @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
*/ #include "./comp_node.h" @@ -21,8 +22,8 @@ using namespace mgb; #include "megbrain/comp_node/alloc.h" -#include #include +#include #include @@ -31,26 +32,23 @@ using namespace mgb; using CudaCompNodeImpl = CudaCompNode::CompNodeImpl; namespace { - size_t get_min_system_memory(size_t available) { - if (available < (1u << 31)) { - // 225MiB - return 225 * 1024 * 1024; - } else { - // max(300 MiB, 0.05 * available) - return std::max(300 * 1024 * 1024, available / 20); - } +size_t get_min_system_memory(size_t available) { + if (available < (1u << 31)) { + // 225MiB + return 225 * 1024 * 1024; + } else { + // max(300 MiB, 0.05 * available) + return std::max(300 * 1024 * 1024, available / 20); } - using CudaHostFunc = megdnn::thin_function; - void CUDART_CB cuda_host_func_caller(void* ud) { - mgb_assert(ud); - CudaHostFunc* func_ptr = reinterpret_cast(ud); - MGB_TRY { - (*func_ptr)(); - } MGB_FINALLY( - delete func_ptr; - ); - } -} // anonymous namespace +} +using CudaHostFunc = megdnn::thin_function; +void CUDART_CB cuda_host_func_caller(void* ud) { + mgb_assert(ud); + CudaHostFunc* func_ptr = reinterpret_cast(ud); + MGB_TRY { (*func_ptr)(); } + MGB_FINALLY(delete func_ptr;); +} +} // anonymous namespace namespace mgb { namespace mem_alloc { @@ -103,7 +101,8 @@ class CudaHostAllocator : public RawAllocator { public: void* alloc(size_t size) override { void* addr; - cudaError_t cuda_error = cudaHostAlloc(&addr, size, cudaHostAllocDefault); + cudaError_t cuda_error = + cudaHostAlloc(&addr, size, cudaHostAllocDefault); if (cuda_error == cudaSuccess) { mgb_assert(addr); return addr; @@ -162,7 +161,7 @@ std::unique_ptr DevMemAlloc::make_cuda_alloc() { } // namespace mgb /* ===================== CudaCompNodeImpl ===================== */ -class CudaCompNode::CompNodeImpl final: public CompNode::Impl { +class CudaCompNode::CompNodeImpl final : public CompNode::Impl { MGB_DYN_TYPE_OBJ_FINAL_DECL; friend class EventImpl; @@ -170,7 +169,7 @@ class CudaCompNode::CompNodeImpl final: public CompNode::Impl { struct DeviceInfo; struct StaticData; - static StaticData *sd; + static StaticData* sd; static Spinlock sd_mtx; #if !MGB_BUILD_SLIM_SERVING std::mutex m_update_mem; @@ -180,17 +179,15 @@ class CudaCompNode::CompNodeImpl final: public CompNode::Impl { //! failed bool m_initialized = false; Locator m_locator, m_locator_logical; - mem_alloc::StreamMemAlloc *m_mem_alloc; - DeviceInfo *m_device_info; + mem_alloc::StreamMemAlloc* m_mem_alloc; + DeviceInfo* m_device_info; std::unique_ptr m_sync_event; Spinlock m_sync_event_mtx; - void activate() { - m_env.cuda_env().activate(); - } + void activate() { m_env.cuda_env().activate(); } - void init(const Locator &locator, const Locator &locator_logical); + void init(const Locator& locator, const Locator& locator_logical); void fini(); //! 
return whether global finalized, and print warning in such case @@ -207,117 +204,111 @@ class CudaCompNode::CompNodeImpl final: public CompNode::Impl { static_cast(self)->free_host(ptr); } +public: + CompNodeImpl() : Impl(static_free_device, static_free_host) {} - public: - CompNodeImpl() : Impl(static_free_device, static_free_host) {} - - void* alloc_device(size_t size) override { - activate(); + void* alloc_device(size_t size) override { + activate(); #if MGB_BUILD_SLIM_SERVING - return m_mem_alloc->alloc(size); + return m_mem_alloc->alloc(size); #else - void* ptr = m_mem_alloc->alloc(size); - { - MGB_LOCK_GUARD(m_update_mem); - ptr2size[ptr] = size; - m_used_mem += size; - } - return ptr; -#endif + void* ptr = m_mem_alloc->alloc(size); + { + MGB_LOCK_GUARD(m_update_mem); + ptr2size[ptr] = size; + m_used_mem += size; } + return ptr; +#endif + } - void free_device(void *ptr); + void free_device(void* ptr); - void *alloc_host(size_t size) override; + void* alloc_host(size_t size) override; - void free_host(void *ptr); + void free_host(void* ptr); - void copy_to_host(void *host_ptr, - const void *device_ptr, size_t size) override { - activate(); - MGB_CUDA_CHECK(cudaMemcpyAsync(host_ptr, device_ptr, size, - cudaMemcpyDeviceToHost, m_env.cuda_env().stream)); - } + void copy_to_host(void* host_ptr, const void* device_ptr, + size_t size) override { + activate(); + MGB_CUDA_CHECK(cudaMemcpyAsync(host_ptr, device_ptr, size, + cudaMemcpyDeviceToHost, + m_env.cuda_env().stream)); + } - void copy_to_device(void *device_ptr, - const void *host_ptr, size_t size) override { - activate(); - MGB_CUDA_CHECK(cudaMemcpyAsync(device_ptr, host_ptr, size, - cudaMemcpyHostToDevice, m_env.cuda_env().stream)); - } + void copy_to_device(void* device_ptr, const void* host_ptr, + size_t size) override { + activate(); + MGB_CUDA_CHECK(cudaMemcpyAsync(device_ptr, host_ptr, size, + cudaMemcpyHostToDevice, + m_env.cuda_env().stream)); + } - void peer_copy_to( - Impl *dest_impl, void *dest, - const void *src, size_t size) override; + void peer_copy_to(Impl* dest_impl, void* dest, const void* src, + size_t size) override; - size_t get_mem_addr_alignment() override { - return m_env.property().mem_alignment; - } + size_t get_mem_addr_alignment() override { + return m_env.property().mem_alignment; + } - std::unique_ptr create_event(size_t flags) override; + std::unique_ptr create_event(size_t flags) override; - void sync() override; + void sync() override; - MemNode mem_node() override; + MemNode mem_node() override; - std::pair get_mem_status_bytes() override { - // explicitly call cuda_env() to ensure async init is finished - m_env.cuda_env().activate(); - size_t tot, free; - MGB_CUDA_CHECK(cudaMemGetInfo(&free, &tot)); - free += m_mem_alloc->get_free_memory_dev().tot; - return {tot, free}; - } + std::pair get_mem_status_bytes() override { + // explicitly call cuda_env() to ensure async init is finished + m_env.cuda_env().activate(); + size_t tot, free; + MGB_CUDA_CHECK(cudaMemGetInfo(&free, &tot)); + free += m_mem_alloc->get_free_memory_dev().tot; + return {tot, free}; + } #if !MGB_BUILD_SLIM_SERVING - std::pair get_free_left_and_right(size_t begin_ptr, size_t end_ptr) override { - return m_mem_alloc->get_free_left_and_right(begin_ptr, end_ptr); - } + std::pair get_free_left_and_right(size_t begin_ptr, + size_t end_ptr) override { + return m_mem_alloc->get_free_left_and_right(begin_ptr, end_ptr); + } #endif - Locator locator() override { - return m_locator; - } + Locator locator() override { return m_locator; } - Locator 
locator_logical() override { - return m_locator_logical; - } + Locator locator_logical() override { return m_locator_logical; } - void add_callback(CudaHostFunc&& cb) override { + void add_callback(CudaHostFunc&& cb) override { #if CUDART_VERSION >= 10000 - activate(); - CudaHostFunc* func_ptr = new CudaHostFunc(std::move(cb)); - MGB_TRY { - MGB_CUDA_CHECK(cudaLaunchHostFunc(m_env.cuda_env().stream, - cuda_host_func_caller, static_cast(func_ptr))); - } MGB_CATCH(..., { - delete func_ptr; - throw; - }); + activate(); + CudaHostFunc* func_ptr = new CudaHostFunc(std::move(cb)); + MGB_TRY { + MGB_CUDA_CHECK(cudaLaunchHostFunc(m_env.cuda_env().stream, + cuda_host_func_caller, + static_cast(func_ptr))); + } + MGB_CATCH(..., { + delete func_ptr; + throw; + }); #else - MGB_MARK_USED_VAR(cb); - MGB_MARK_USED_VAR(cuda_host_func_caller); - mgb_throw( - MegBrainError, - "add_callback only support in cuda10.0 and later version"); + MGB_MARK_USED_VAR(cb); + MGB_MARK_USED_VAR(cuda_host_func_caller); + mgb_throw(MegBrainError, + "add_callback only support in cuda10.0 and later version"); #endif - } + } - uint64_t get_uid() override { - return m_uid; - } + uint64_t get_uid() override { return m_uid; } #if !MGB_BUILD_SLIM_SERVING - size_t get_used_memory() override { - return m_used_mem; - } + size_t get_used_memory() override { return m_used_mem; } #endif - private: - uint64_t m_uid; +private: + uint64_t m_uid; #if !MGB_BUILD_SLIM_SERVING - std::unordered_map ptr2size; - size_t m_used_mem = 0; + std::unordered_map ptr2size; + size_t m_used_mem = 0; #endif }; MGB_DYN_TYPE_OBJ_FINAL_IMPL(CudaCompNode::CompNodeImpl); @@ -326,15 +317,11 @@ struct CudaCompNodeImpl::DeviceInfo { int dev_num = -1; std::unique_ptr mem_alloc; - bool init_done() const { - return mem_alloc.get(); - } + bool init_done() const { return mem_alloc.get(); } - void init(const CompNodeEnv &env); + void init(const CompNodeEnv& env); - void fini() { - mem_alloc.reset(); - } + void fini() { mem_alloc.reset(); } }; struct CudaCompNodeImpl::StaticData { @@ -347,21 +334,21 @@ struct CudaCompNodeImpl::StaticData { std::unique_ptr host_alloc; CudaCompNode::CompNodeImpl node[MAX_NR_COMP_NODE]; DeviceInfo dev_info[MAX_NR_DEVICE]; - int nr_node = 0, //!< number of loaded node[] - nr_dev_used = 0; //!< number of used dev_info[] + int nr_node = 0, //!< number of loaded node[] + nr_dev_used = 0; //!< number of used dev_info[] - StaticData() : host_alloc( - mem_alloc::SimpleCachingAlloc::make( - std::make_unique())) { + StaticData() + : host_alloc(mem_alloc::SimpleCachingAlloc::make( + std::make_unique())) { prealloc_config.max_overhead = 0; prealloc_config.alignment = 1; host_alloc->alignment(1); } ~StaticData() { - for (int i = 0; i < nr_node; ++ i) + for (int i = 0; i < nr_node; ++i) node[i].fini(); - for (int i = 0; i < nr_dev_used; ++ i) + for (int i = 0; i < nr_dev_used; ++i) dev_info[i].fini(); } @@ -382,21 +369,21 @@ struct CudaCompNodeImpl::StaticData { CudaCompNodeImpl::StaticData* CudaCompNodeImpl::sd = nullptr; Spinlock CudaCompNodeImpl::sd_mtx; -void CudaCompNodeImpl::init( - const Locator &locator, const Locator &locator_logical) { +void CudaCompNodeImpl::init(const Locator& locator, + const Locator& locator_logical) { m_locator = locator; m_locator_logical = locator_logical; m_initialized = true; #if defined(__linux__) || defined(TARGET_OS_MAC) - FILE *fp; + FILE* fp; fp = fopen("/dev/urandom", "r"); mgb_assert(fread(&m_uid, sizeof(m_uid), 1, fp) == 1); fclose(fp); #else m_uid = std::chrono::duration_cast( - 
std::chrono::system_clock::now().time_since_epoch() - ).count(); + std::chrono::system_clock::now().time_since_epoch()) + .count(); #endif auto on_succ = [this](cudaStream_t stream) { @@ -404,8 +391,8 @@ void CudaCompNodeImpl::init( log_comp_node_created(locator, m_locator_logical); MGB_LOCK_GUARD(sd->mtx); - DeviceInfo *dev_info = nullptr; - for (int i = 0; i < sd->nr_dev_used; ++ i) { + DeviceInfo* dev_info = nullptr; + for (int i = 0; i < sd->nr_dev_used; ++i) { if (sd->dev_info[i].dev_num == locator.device) { dev_info = &sd->dev_info[i]; break; @@ -416,7 +403,7 @@ void CudaCompNodeImpl::init( dev_info = &sd->dev_info[sd->nr_dev_used]; dev_info->init(m_env); // note: add nr_dev_used only after init succeeds - ++ sd->nr_dev_used; + ++sd->nr_dev_used; } m_device_info = dev_info; m_mem_alloc = @@ -428,9 +415,8 @@ void CudaCompNodeImpl::init( m_initialized = false; }; - m_env.init_cuda_async( - locator.device, make_comp_node_from_impl(this), - {on_succ, on_error}); + m_env.init_cuda_async(locator.device, make_comp_node_from_impl(this), + {on_succ, on_error}); } void CudaCompNodeImpl::fini() { @@ -444,7 +430,7 @@ void CudaCompNodeImpl::fini() { m_initialized = false; } -void CudaCompNodeImpl::free_device(void *ptr) { +void CudaCompNodeImpl::free_device(void* ptr) { if (check_global_finalized()) return; @@ -452,13 +438,13 @@ void CudaCompNodeImpl::free_device(void *ptr) { #if !MGB_BUILD_SLIM_SERVING { MGB_LOCK_GUARD(m_update_mem); - mgb_assert(ptr2size.find(ptr) != ptr2size.end(), "ptr %p not found!", ptr); + mgb_assert(ptr2size.find(ptr) != ptr2size.end(), "ptr %p not found!", + ptr); m_used_mem -= ptr2size.at(ptr); ptr2size.erase(ptr); } #endif m_mem_alloc->free(ptr); - } void* CudaCompNodeImpl::alloc_host(size_t size) { @@ -468,38 +454,37 @@ void* CudaCompNodeImpl::alloc_host(size_t size) { } void CudaCompNodeImpl::free_host(void* ptr) { - if (check_global_finalized()) return; + if (check_global_finalized()) + return; sd->host_alloc->free(ptr); } -void CudaCompNodeImpl::peer_copy_to( - Impl *dest_impl, void *dest, const void *src, size_t size) { +void CudaCompNodeImpl::peer_copy_to(Impl* dest_impl, void* dest, + const void* src, size_t size) { if (dest_impl->same_type()) { - auto &&dst_env = static_cast( - dest_impl)->m_env.cuda_env(); - auto &&src_env = m_env.cuda_env(); + auto&& dst_env = + static_cast(dest_impl)->m_env.cuda_env(); + auto&& src_env = m_env.cuda_env(); activate(); if (dst_env.device == src_env.device) { - MGB_CUDA_CHECK(cudaMemcpyAsync(dest, src, size, - cudaMemcpyDeviceToDevice, - dst_env.stream)); + MGB_CUDA_CHECK(cudaMemcpyAsync( + dest, src, size, cudaMemcpyDeviceToDevice, dst_env.stream)); } else { enable_peer_access(src_env.device, dst_env.device); enable_peer_access(dst_env.device, src_env.device); - MGB_CUDA_CHECK(cudaMemcpyPeerAsync( - dest, dst_env.device, - src, src_env.device, size, - dst_env.stream)); + MGB_CUDA_CHECK(cudaMemcpyPeerAsync(dest, dst_env.device, src, + src_env.device, size, + dst_env.stream)); } return; } mgb_assert(dest_impl->env().property().type == DeviceType::CPU, - "cuda peer_copy_to only implemented for CPU"); + "cuda peer_copy_to only implemented for CPU"); auto copy = [this, dest, src, size]() { auto stream = m_env.cuda_env().stream; m_env.cuda_env().activate(); - MGB_CUDA_CHECK(cudaMemcpyAsync( - dest, src, size, cudaMemcpyDeviceToHost, stream)); + MGB_CUDA_CHECK(cudaMemcpyAsync(dest, src, size, cudaMemcpyDeviceToHost, + stream)); MGB_CUDA_CHECK(cudaStreamSynchronize(stream)); }; dest_impl->env().cpu_env().dispatch(copy); @@ -514,13 
+499,13 @@ MemNode CudaCompNodeImpl::mem_node() { void CudaCompNodeImpl::sync() { activate(); - // do not use MGB_CUDA_CHECK(cudaStreamSynchronize(m_env->stream)) since other - // threads may be adding operations into the stream, and we only care about - // previous operations in current thread. However docs of + // do not use MGB_CUDA_CHECK(cudaStreamSynchronize(m_env->stream)) since + // other threads may be adding operations into the stream, and we only care + // about previous operations in current thread. However docs of // cudaStreamSynchronize did not describe details of such condition, so we // use manual event implementation - Event *event; + Event* event; { MGB_LOCK_GUARD(m_sync_event_mtx); if (!m_sync_event) @@ -532,8 +517,8 @@ void CudaCompNodeImpl::sync() { } void CudaCompNodeImpl::enable_peer_access(int dev0, int dev1) { - static bool already_enabled[ - StaticData::MAX_NR_DEVICE][StaticData::MAX_NR_DEVICE]; + static bool already_enabled[StaticData::MAX_NR_DEVICE] + [StaticData::MAX_NR_DEVICE]; if (already_enabled[dev0][dev1]) return; @@ -550,7 +535,8 @@ void CudaCompNodeImpl::enable_peer_access(int dev0, int dev1) { auto err = cudaDeviceEnablePeerAccess(dev1, 0); if (err != cudaSuccess) { mgb_log_error("failed to enable peer access from %d to %d: %s(%d)", - dev0, dev1, cudaGetErrorString(err), static_cast(err)); + dev0, dev1, cudaGetErrorString(err), + static_cast(err)); cudaGetLastError(); } } @@ -563,33 +549,29 @@ void CudaCompNodeImpl::enable_peer_access(int dev0, int dev1) { MGB_CUDA_CHECK(cudaMalloc(&dp0, sizeof(int))); MGB_CUDA_CHECK(cudaSetDevice(dev1)); MGB_CUDA_CHECK(cudaMalloc(&dp1, sizeof(int))); - MGB_CUDA_CHECK(cudaMemcpy(dp0, &v0, sizeof(int), - cudaMemcpyHostToDevice)); - MGB_CUDA_CHECK(cudaMemcpy(dp1, &v1, sizeof(int), - cudaMemcpyHostToDevice)); + MGB_CUDA_CHECK(cudaMemcpy(dp0, &v0, sizeof(int), cudaMemcpyHostToDevice)); + MGB_CUDA_CHECK(cudaMemcpy(dp1, &v1, sizeof(int), cudaMemcpyHostToDevice)); MGB_CUDA_CHECK(cudaMemcpyPeer(dp1, dev1, dp0, dev0, sizeof(int))); int get = 0; - MGB_CUDA_CHECK(cudaMemcpy(&get, dp1, sizeof(int), - cudaMemcpyDeviceToHost)); + MGB_CUDA_CHECK(cudaMemcpy(&get, dp1, sizeof(int), cudaMemcpyDeviceToHost)); mgb_throw_if(get != 1, CudaError, - "P2P copy (%d => %d) check failed; consider disabling " - "Access Control Services(ACS) for the PCI device", - dev0, dev1); - + "P2P copy (%d => %d) check failed; consider disabling " + "Access Control Services(ACS) for the PCI device", + dev0, dev1); already_enabled[dev0][dev1] = true; } /* ===================== CudaCompNodeImpl::DeviceInfo ===================== */ -void CudaCompNodeImpl::DeviceInfo::init(const CompNodeEnv &env) { +void CudaCompNodeImpl::DeviceInfo::init(const CompNodeEnv& env) { mgb_assert(!mem_alloc); #if 0 // forward cudaMalloc mem_alloc = mem_alloc::DevMemAlloc::make_cuda_alloc(); #else - auto &&cuenv = env.cuda_env(); + auto&& cuenv = env.cuda_env(); cuenv.activate(); dev_num = cuenv.device; auto reserve_size = StaticData::get_mem_reserve_size(); @@ -600,9 +582,10 @@ void CudaCompNodeImpl::DeviceInfo::init(const CompNodeEnv &env) { mem_alloc->prealloc_config(sd->prealloc_config); auto align = env.property().mem_alignment; mem_alloc->alignment(align); - mgb_log_debug("cuda: gpu%d: name=`%s' dyn_mem_reserve=%.2fMiB alignment=0x%zx", - dev_num, cuenv.device_prop.name, - reserve_size / 1024.0 / 1024, align); + mgb_log_debug( + "cuda: gpu%d: name=`%s' dyn_mem_reserve=%.2fMiB alignment=0x%zx", + dev_num, cuenv.device_prop.name, reserve_size / 1024.0 / 1024, + align); #endif } @@ 
-631,14 +614,14 @@ bool CudaCompNodeImpl::check_global_finalized() { /* ===================== EventImpl ===================== */ -class CudaCompNode::EventImpl final: public EventImplHelper { +class CudaCompNode::EventImpl final : public EventImplHelper { bool m_init_finished = false; - CudaCompNodeImpl * const m_comp_node_impl; + CudaCompNodeImpl* const m_comp_node_impl; cudaEvent_t m_cuda_event; void do_record() override { m_comp_node_impl->activate(); - auto &&env = m_comp_node_impl->m_env.cuda_env(); + auto&& env = m_comp_node_impl->m_env.cuda_env(); MGB_CUDA_CHECK(cudaEventRecord(m_cuda_event, env.stream)); } @@ -649,56 +632,51 @@ class CudaCompNode::EventImpl final: public EventImplHelper { return true; if (err == cudaErrorNotReady) return false; - mgb_throw(CudaError, "failed to query event: %d: %s", - int(err), cudaGetErrorString(err)); + mgb_throw(CudaError, "failed to query event: %d: %s", int(err), + cudaGetErrorString(err)); } void host_wait_cv() override { MGB_CUDA_CHECK(cudaEventSynchronize(m_cuda_event)); } - double do_elapsed_time_until(EventImplHelper &end) override { + double do_elapsed_time_until(EventImplHelper& end) override { m_comp_node_impl->activate(); float ret = 0.0; - MGB_CUDA_CHECK(cudaEventElapsedTime(&ret, m_cuda_event, - static_cast(end).m_cuda_event)); + MGB_CUDA_CHECK(cudaEventElapsedTime( + &ret, m_cuda_event, static_cast(end).m_cuda_event)); return static_cast(ret) * 1e-3; } - void do_device_wait_by(Impl *cn_impl) override; - - public: + void do_device_wait_by(Impl* cn_impl) override; - EventImpl(CudaCompNodeImpl *comp_node_impl, size_t create_flags): - EventImplHelper(comp_node_impl, create_flags), - m_comp_node_impl{comp_node_impl} - { - m_comp_node_impl->activate(); - size_t cuda_flags = cudaEventDisableTiming; - if (create_flags & NEED_TIMER) - cuda_flags = 0; - MGB_CUDA_CHECK(cudaEventCreateWithFlags(&m_cuda_event, cuda_flags)); - m_init_finished = true; - } - - ~EventImpl() { - if (m_init_finished) { - MGB_TRY { - MGB_CUDA_CHECK(cudaEventDestroy(m_cuda_event)); - } MGB_CATCH(MegBrainError &exc, { - mgb_log_error("failed to destroy cuda event: %s", - exc.what()); - }) - } +public: + EventImpl(CudaCompNodeImpl* comp_node_impl, size_t create_flags) + : EventImplHelper(comp_node_impl, create_flags), + m_comp_node_impl{comp_node_impl} { + m_comp_node_impl->activate(); + size_t cuda_flags = cudaEventDisableTiming; + if (create_flags & NEED_TIMER) + cuda_flags = 0; + MGB_CUDA_CHECK(cudaEventCreateWithFlags(&m_cuda_event, cuda_flags)); + m_init_finished = true; + } + + ~EventImpl() { + if (m_init_finished) { + MGB_TRY { MGB_CUDA_CHECK(cudaEventDestroy(m_cuda_event)); } + MGB_CATCH(MegBrainError & exc, { + mgb_log_error("failed to destroy cuda event: %s", exc.what()); + }) } + } }; -std::unique_ptr -CudaCompNodeImpl::create_event(size_t flags) { +std::unique_ptr CudaCompNodeImpl::create_event(size_t flags) { return std::make_unique(this, flags); } -void CudaCompNode::EventImpl::do_device_wait_by(Impl *cn_impl) { +void CudaCompNode::EventImpl::do_device_wait_by(Impl* cn_impl) { if (cn_impl->dyn_typeinfo() == CudaCompNodeImpl::typeinfo()) { auto imp = static_cast(cn_impl); auto stream = imp->m_env.cuda_env().stream; @@ -716,7 +694,6 @@ void CudaCompNode::EventImpl::do_device_wait_by(Impl *cn_impl) { mgb_throw(MegBrainError, "unimplemented event device_wait_by config"); } - /* ===================== CudaCompNode static methods ===================== */ bool CudaCompNode::available() { @@ -729,7 +706,10 @@ bool CudaCompNode::available() { result = err == 
cudaSuccess && ndev > 0; if (!result) { mgb_log_warn("cuda unavailable: %s(%d) ndev=%d", - cudaGetErrorString(err), static_cast(err), ndev); + cudaGetErrorString(err), static_cast(err), ndev); + } + if (err == cudaErrorInitializationError) { + mgb_throw(std::runtime_error, "cuda initialization error."); } } return result; @@ -769,7 +749,7 @@ CompNode::Impl* CudaCompNode::load_cuda(const Locator& locator, "request gpu%d out of valid range [0, %d)", locator.device, nr_gpu); - auto &&sdptr = CudaCompNodeImpl::sd; + auto&& sdptr = CudaCompNodeImpl::sd; { MGB_LOCK_GUARD(CudaCompNodeImpl::sd_mtx); if (!sdptr) { @@ -777,17 +757,18 @@ CompNode::Impl* CudaCompNode::load_cuda(const Locator& locator, // global finalize using T = CudaCompNodeImpl::StaticData; static std::aligned_storage_t storage; - sdptr = new(&storage)T; + sdptr = new (&storage) T; } } - auto &&sd = *sdptr; + auto&& sd = *sdptr; MGB_LOCK_GUARD(sd.mtx); - CompNodeImpl *available_node = nullptr; - for (int i = 0; i < sd.nr_node; ++ i) { - auto &&cur = sd.node[i]; + CompNodeImpl* available_node = nullptr; + for (int i = 0; i < sd.nr_node; ++i) { + auto&& cur = sd.node[i]; if (cur.m_initialized) { - if (cur.m_locator == locator && cur.m_locator_logical == locator_logical) { + if (cur.m_locator == locator && + cur.m_locator_logical == locator_logical) { return &cur; } } else { @@ -797,11 +778,10 @@ CompNode::Impl* CudaCompNode::load_cuda(const Locator& locator, if (!available_node) { mgb_assert(sd.nr_node < sd.MAX_NR_COMP_NODE, - "too many CompNode allocated"); - available_node = &sd.node[sd.nr_node ++]; + "too many CompNode allocated"); + available_node = &sd.node[sd.nr_node++]; } - mgb_assert(locator.device < sd.MAX_NR_DEVICE, - "device number too large"); + mgb_assert(locator.device < sd.MAX_NR_DEVICE, "device number too large"); mgb_assert(!available_node->m_initialized); available_node->init(locator, locator_logical); @@ -816,13 +796,13 @@ void CudaCompNode::try_coalesce_all_free_memory() { return; size_t size = 0; - for (int i = 0; i < sd->nr_dev_used; ++ i) { - size += sd->dev_info[i].mem_alloc-> - gather_stream_free_blk_and_release_full(); + for (int i = 0; i < sd->nr_dev_used; ++i) { + size += sd->dev_info[i] + .mem_alloc->gather_stream_free_blk_and_release_full(); } if (size) { mgb_log_debug("%zu bytes freed by try_coalesce_all_free_memory()", - size); + size); } } @@ -831,9 +811,9 @@ void CudaCompNode::sync_all() { if (!sd) return; - for (int i = 0; ; ++ i) { + for (int i = 0;; ++i) { // ensure async init finished - CompNodeEnv *env; + CompNodeEnv* env; { MGB_LOCK_GUARD(sd->mtx); if (i >= sd->nr_node) { @@ -851,12 +831,12 @@ void CudaCompNode::sync_all() { } } -void CudaCompNode::foreach(thin_function callback) { +void CudaCompNode::foreach (thin_function callback) { auto sd = CudaCompNodeImpl::sd; if (!sd) return; - for (int i = 0; ; ++ i) { + for (int i = 0;; ++i) { CompNode cur; { MGB_LOCK_GUARD(sd->mtx); @@ -875,8 +855,9 @@ size_t CudaCompNode::get_device_count(bool warn) { if (cnt == -1) { auto err = cudaGetDeviceCount(&cnt); if (err != cudaSuccess) { - if (warn) mgb_log_error("cudaGetDeviceCount failed: %s (err %d)", - cudaGetErrorString(err), int(err)); + if (warn) + mgb_log_error("cudaGetDeviceCount failed: %s (err %d)", + cudaGetErrorString(err), int(err)); cnt = 0; } mgb_assert(cnt >= 0); @@ -884,26 +865,26 @@ size_t CudaCompNode::get_device_count(bool warn) { return cnt; } -void CudaCompNode::set_prealloc_config(size_t alignment, size_t min_req, +void CudaCompNode::set_prealloc_config(size_t alignment, size_t 
min_req, size_t max_overhead, double growth_factor) { - auto &&sdptr = CudaCompNodeImpl::sd; + auto&& sdptr = CudaCompNodeImpl::sd; { MGB_LOCK_GUARD(CudaCompNodeImpl::sd_mtx); if (!sdptr) { using T = CudaCompNodeImpl::StaticData; static std::aligned_storage_t storage; - sdptr = new(&storage)T; + sdptr = new (&storage) T; sdptr->prealloc_config.alignment = alignment; sdptr->prealloc_config.min_req = min_req; sdptr->prealloc_config.growth_factor = growth_factor; sdptr->prealloc_config.max_overhead = max_overhead; } else { mgb_log_warn( - "invalid call to set_prealloc_config, will fallback to " - "default config; " - "prealloc_config should be specified before any CUDA " - "memory allocation"); + "invalid call to set_prealloc_config, will fallback to " + "default config; " + "prealloc_config should be specified before any CUDA " + "memory allocation"); } } } @@ -913,27 +894,23 @@ void CudaCompNode::set_prealloc_config(size_t alignment, size_t min_req, bool CudaCompNode::available() { return false; } -void CudaCompNode::try_coalesce_all_free_memory() { -} -void CudaCompNode::foreach(thin_function) { -} -void CudaCompNode::finalize() { -} +void CudaCompNode::try_coalesce_all_free_memory() {} +void CudaCompNode::foreach (thin_function) {} +void CudaCompNode::finalize() {} size_t CudaCompNode::get_device_count(bool warn) { return 0; } CudaCompNode::Impl* CudaCompNode::load_cuda(const Locator&, const Locator&) { mgb_throw(MegBrainError, "cuda disabled at compile time"); } -void CudaCompNode::sync_all() { -} +void CudaCompNode::sync_all() {} -void CudaCompNode::set_prealloc_config(size_t alignment, size_t min_req, +void CudaCompNode::set_prealloc_config(size_t alignment, size_t min_req, size_t max_overhead, double growth_factor) {} #undef err -#endif // MGB_CUDA +#endif // MGB_CUDA // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
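
For reference, a minimal usage sketch of the pattern this check allows (an illustration, not part of the patch; it assumes a MegEngine build with at least two visible GPUs, and the explicit per-rank device string is an illustrative choice). Keeping all CUDA work inside the launched worker leaves the parent process free of CUDA state, so the new _check_device_initialized() call in _run_wrapped passes in every forked child; creating a "gpuN" tensor before invoking the launcher, as in test_cuda_init_before_fork above, makes each worker fail instead, because a CUDA context initialized in the parent is not usable after fork().

    import megengine as mge
    import megengine.distributed as dist

    @dist.launcher(n_gpus=2)
    def worker():
        # CUDA is first touched here, inside the forked worker, so the
        # pre-fork check added by this patch succeeds.
        x = mge.tensor([1.0, 2.0], device="gpu" + str(dist.get_rank()))
        print(dist.get_rank(), x.numpy())

    if __name__ == "__main__":
        # Do not create any "gpuN" tensor here before calling worker();
        # that would initialize CUDA in the parent and trip the check in
        # every child process.
        worker()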