From d3154a45bc5e7166cb2fe064ce65d6b955cec390 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Wed, 18 Nov 2020 13:30:52 +0800
Subject: [PATCH] feat(distributed): support distributed opr for rocm

GitOrigin-RevId: 4840100d07dbaa2b7d8e3e113b444ddf81eeea51
---
 CMakeLists.txt                                  |  6 ++++--
 src/core/impl/comp_node/comp_node.cpp           |  2 ++
 src/core/impl/comp_node/rocm/comp_node.cpp      | 16 ++++++++++++++++
 src/opr-mm/impl/collective_comm.cpp             |  9 +--------
 src/opr-mm/impl/io_remote.cpp                   |  8 ++------
 src/opr-mm/impl/megray_helper.cpp               | 13 +++++++++++++
 src/opr-mm/impl/zmq_rpc.cpp                     |  2 --
 src/opr-mm/include/megbrain/opr/megray_helper.h |  4 ++++
 src/opr-mm/include/megbrain/opr/zmq_rpc.h       |  2 --
 9 files changed, 42 insertions(+), 20 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index b78b7a95..59ca13c4 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -339,8 +339,8 @@ if(MGE_BUILD_IMPERATIVE_RT)
   set(CMAKE_CXX_STANDARD 17)
 endif()
 
-if(NOT MGE_WITH_CUDA)
-  message(STATUS "Disable distributed support, as CUDA is not enabled.")
+if(NOT ${MGE_WITH_CUDA} AND NOT ${MGE_WITH_ROCM})
+  message(STATUS "Disable distributed support, as both CUDA and ROCm are disabled.")
   set(MGE_WITH_DISTRIBUTED OFF)
 endif()
 
@@ -903,6 +903,8 @@ if(MGE_WITH_JIT_MLIR OR MGE_BUILD_IMPERATIVE_RT)
 endif()
 
 if(MGE_WITH_DISTRIBUTED)
+  set(MEGRAY_WITH_NCCL ${MGE_WITH_CUDA} CACHE BOOL "Override MegRay option" FORCE)
+  set(MEGRAY_WITH_RCCL ${MGE_WITH_ROCM} CACHE BOOL "Override MegRay option" FORCE)
   add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/MegRay)
 endif()
 
diff --git a/src/core/impl/comp_node/comp_node.cpp b/src/core/impl/comp_node/comp_node.cpp
index 47b9d4e0..a9a17c41 100644
--- a/src/core/impl/comp_node/comp_node.cpp
+++ b/src/core/impl/comp_node/comp_node.cpp
@@ -79,6 +79,8 @@ namespace {
         if (g_unspec_locator_type == DT::UNSPEC) {
             if (CudaCompNode::available()) {
                 g_unspec_locator_type = DT::CUDA;
+            } else if (ROCmCompNode::available()) {
+                g_unspec_locator_type = DT::ROCM;
             } else {
                 g_unspec_locator_type = DT::CPU;
             }
diff --git a/src/core/impl/comp_node/rocm/comp_node.cpp b/src/core/impl/comp_node/rocm/comp_node.cpp
index 07a53d7a..4f553d1a 100644
--- a/src/core/impl/comp_node/rocm/comp_node.cpp
+++ b/src/core/impl/comp_node/rocm/comp_node.cpp
@@ -217,6 +217,11 @@ public:
     Locator locator() override { return m_locator; }
 
     Locator locator_logical() override { return m_locator_logical; }
+
+    uint64_t get_uid() override { return m_uid; }
+
+private:
+    uint64_t m_uid;
 };
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(ROCmCompNode::CompNodeImpl);
 
@@ -278,6 +283,17 @@ void ROCmCompNodeImpl::init(const Locator& locator,
     m_locator_logical = locator_logical;
     m_initialized = true;
 
+#if defined(__linux__) || defined(TARGET_OS_MAC)
+    FILE *fp;
+    fp = fopen("/dev/urandom", "r");
+    mgb_assert(fread(&m_uid, sizeof(m_uid), 1, fp) == 1);
+    fclose(fp);
+#else
+    m_uid = std::chrono::duration_cast<std::chrono::nanoseconds>(
+            std::chrono::system_clock::now().time_since_epoch()
+        ).count();
+#endif
+
     auto on_succ = [this](hipStream_t stream) {
         auto locator = m_locator;
         log_comp_node_created(locator, m_locator_logical);
diff --git a/src/opr-mm/impl/collective_comm.cpp b/src/opr-mm/impl/collective_comm.cpp
index 786a7516..f28617fb 100644
--- a/src/opr-mm/impl/collective_comm.cpp
+++ b/src/opr-mm/impl/collective_comm.cpp
@@ -47,9 +47,6 @@ const char* get_param_name(CollectiveComm::Param param) {
     }
 }
 
-cudaStream_t get_stream(VarNode* var) {
-    return CompNodeEnv::from_comp_node(var->comp_node()).cuda_env().stream;
-}
 }  // anonymous namespace
 
 /* ================= ModeTrait ================= */
@@ -519,8 +516,6 @@ CollectiveComm::CollectiveComm(
     // add input
     mgb_assert(inputs.size() <= 1, "one or zero input expected, got %zu", inputs.size());
     if (inputs.size() > 0) {
-        mgb_assert(inputs[0]->comp_node().device_type() == CompNode::DeviceType::CUDA,
-                "CollectiveComm currectly only supports CUDA");
         add_input({inputs[0]});
     }
 
@@ -531,8 +526,6 @@ CollectiveComm::CollectiveComm(
     const auto& cns = config.comp_node();
     mgb_assert(cns.size() <= 1, "one or zero comp node expected, got %zu", cns.size());
     if (cns.size() > 0) {
-        mgb_assert(cns[0].device_type() == CompNode::DeviceType::CUDA,
-                "CollectiveComm currectly only supports CUDA");
         output(0)->comp_node(cns[0]);
     } else {
         output(0)->comp_node(inputs[0]->comp_node());
@@ -609,7 +602,7 @@ void CollectiveComm::opr_register() {
             reg_info.hash, m_key, m_nr_devices, m_rank,
             get_megray_backend(m_backend), m_group_client);
 
-    m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0)));
+    m_megray_ctx = get_megray_context(output(0)->comp_node());
 
     m_init = true;
 }
diff --git a/src/opr-mm/impl/io_remote.cpp b/src/opr-mm/impl/io_remote.cpp
index bc352617..ff5c7aa5 100644
--- a/src/opr-mm/impl/io_remote.cpp
+++ b/src/opr-mm/impl/io_remote.cpp
@@ -18,10 +18,6 @@
 using namespace mgb;
 using namespace opr;
 
-cudaStream_t get_stream(VarNode* var) {
-    return CompNodeEnv::from_comp_node(var->comp_node()).cuda_env().stream;
-}
-
 /* ===================== RemoteSend ===================== */
 
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend);
@@ -70,7 +66,7 @@ void RemoteSend::scn_do_execute() {
         m_megray_comm = MegRayCommBuilder::get_megray_comm(
                 reg_info.hash, m_key, 2, 0, MegRay::MEGRAY_NCCL, m_group_client);
 
-        m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0)));
+        m_megray_ctx = get_megray_context(output(0)->comp_node());
 
         m_init = true;
     }
@@ -207,7 +203,7 @@ void RemoteRecv::scn_do_execute() {
         m_megray_comm = MegRayCommBuilder::get_megray_comm(
                 reg_info.hash, m_key, 2, 1, MegRay::MEGRAY_NCCL, m_group_client);
 
-        m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0)));
+        m_megray_ctx = get_megray_context(output(0)->comp_node());
 
         m_init = true;
     }
diff --git a/src/opr-mm/impl/megray_helper.cpp b/src/opr-mm/impl/megray_helper.cpp
index 2f4c688f..ba124ecb 100644
--- a/src/opr-mm/impl/megray_helper.cpp
+++ b/src/opr-mm/impl/megray_helper.cpp
@@ -10,6 +10,7 @@
  */
 
 #include "megbrain/opr/megray_helper.h"
+#include "megbrain/comp_node_env.h"
 
 using namespace mgb;
 using namespace opr;
@@ -34,6 +35,8 @@ MegRay::DType mgb::opr::get_megray_dtype(megdnn::DType dtype) {
 MegRay::Backend mgb::opr::get_megray_backend(const std::string& backend) {
     if (backend == "nccl") {
         return MegRay::MEGRAY_NCCL;
+    } else if (backend == "rccl") {
+        return MegRay::MEGRAY_RCCL;
     } else if (backend == "ucx") {
         return MegRay::MEGRAY_UCX;
     } else {
@@ -41,6 +44,16 @@ MegRay::Backend mgb::opr::get_megray_backend(const std::string& backend) {
     }
 }
 
+std::shared_ptr<MegRay::Context> mgb::opr::get_megray_context(CompNode comp_node) {
+#if MGB_CUDA
+    return MegRay::CudaContext::make(CompNodeEnv::from_comp_node(comp_node).cuda_env().stream);
+#elif MGB_ROCM
+    return MegRay::HipContext::make(CompNodeEnv::from_comp_node(comp_node).rocm_env().stream);
+#else
+#error "neither CUDA nor ROCm is enabled"
+#endif
+}
+
 bool MegRayCommBuilder::find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm) {
     std::unique_lock<std::mutex> lk(m_map_mtx);
     auto it = m_megray_comms.find(hash);
diff --git a/src/opr-mm/impl/zmq_rpc.cpp b/src/opr-mm/impl/zmq_rpc.cpp
index 70cc4068..af64f152 100644
--- a/src/opr-mm/impl/zmq_rpc.cpp
+++ b/src/opr-mm/impl/zmq_rpc.cpp
@@ -1,6 +1,5 @@
 #include "megbrain_build_config.h"
 
-#if MGB_CUDA
 #include "megbrain/opr/zmq_rpc.h"
 #include "megbrain/common.h"
 #include "megbrain/exception.h"
@@ -228,4 +227,3 @@ void ZmqRpcClient::request(message_t& request, message_t& reply) {
     DISCARD_RETVAL(client->recv(reply));
     add_socket(client);
 }
-#endif // MGB_CUDA
diff --git a/src/opr-mm/include/megbrain/opr/megray_helper.h b/src/opr-mm/include/megbrain/opr/megray_helper.h
index ac5f4787..3f110660 100644
--- a/src/opr-mm/include/megbrain/opr/megray_helper.h
+++ b/src/opr-mm/include/megbrain/opr/megray_helper.h
@@ -12,7 +12,9 @@
 #pragma once
 
 #include
+#include
 
+#include "megbrain/comp_node.h"
 #include "megbrain/opr/group_manager.h"
 #include "megray.h"
 
@@ -23,6 +25,8 @@
 MegRay::DType get_megray_dtype(megdnn::DType);
 
 MegRay::Backend get_megray_backend(const std::string& backend);
 
+std::shared_ptr<MegRay::Context> get_megray_context(CompNode comp_node);
+
 /*!
  * gather MegRay unique ids and build communicator, use hash for deduplication
diff --git a/src/opr-mm/include/megbrain/opr/zmq_rpc.h b/src/opr-mm/include/megbrain/opr/zmq_rpc.h
index 49048562..49f43ad9 100644
--- a/src/opr-mm/include/megbrain/opr/zmq_rpc.h
+++ b/src/opr-mm/include/megbrain/opr/zmq_rpc.h
@@ -2,7 +2,6 @@
 #include "megbrain_build_config.h"
 
-#if MGB_CUDA
 #include
 #include
 #include
 
@@ -101,4 +100,3 @@ private:
     std::vector> m_own_sockets;
 };
 } // namespace ZmqRpc
-#endif