From 116eee5231df575b8fa885dd5d98403ae428d9d5 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 11 Sep 2020 11:13:13 +0800 Subject: [PATCH] build(third_party): update megray GitOrigin-RevId: da5e05f82b5112474d51f9eab78318b1d6432742 --- src/opr-mm/impl/collective_comm.cpp | 27 ------------------------- src/opr-mm/impl/io_remote.cpp | 10 +++++---- src/opr-mm/impl/megray_helper.cpp | 27 +++++++++++++++++++++++++ src/opr-mm/include/megbrain/opr/megray_helper.h | 5 ++++- 4 files changed, 37 insertions(+), 32 deletions(-) diff --git a/src/opr-mm/impl/collective_comm.cpp b/src/opr-mm/impl/collective_comm.cpp index 8305b188..46e43872 100644 --- a/src/opr-mm/impl/collective_comm.cpp +++ b/src/opr-mm/impl/collective_comm.cpp @@ -47,33 +47,6 @@ const char* get_param_name(CollectiveComm::Param param) { } } -MegRay::DType get_megray_dtype(megdnn::DType dtype) { - switch(dtype.enumv()) { - case DTypeEnum::Int8: - return MegRay::DType::MEGRAY_INT8; - case DTypeEnum::Int32: - return MegRay::DType::MEGRAY_INT32; - case DTypeEnum::Float32: - return MegRay::DType::MEGRAY_FLOAT32; -#ifndef MEGDNN_DISABLE_FLOAT16 - case DTypeEnum::Float16: - return MegRay::DType::MEGRAY_FLOAT16; -#endif - default: - mgb_throw(MegBrainError, "bad CollectiveComm dtype"); - } -} - -MegRay::Backend get_megray_backend(const std::string& backend) { - if (backend == "nccl") { - return MegRay::MEGRAY_NCCL; - } else if (backend == "ucx") { - return MegRay::MEGRAY_UCX; - } else { - mgb_throw(MegBrainError, "back CollectiveComm backend"); - } -} - cudaStream_t get_stream(VarNode* var) { return CompNodeEnv::from_comp_node(var->comp_node()).cuda_env().stream; } diff --git a/src/opr-mm/impl/io_remote.cpp b/src/opr-mm/impl/io_remote.cpp index d26b2491..cddd08fc 100644 --- a/src/opr-mm/impl/io_remote.cpp +++ b/src/opr-mm/impl/io_remote.cpp @@ -82,8 +82,9 @@ void RemoteSend::scn_do_execute() { for (size_t i = 0; i < ishp.ndim; i++) { data_size *= ishp[i]; } - data_size *= tensor.dtype().size(); - auto status = m_megray_comm->send(tensor.raw_ptr(), data_size, 1, m_megray_ctx); + auto status = m_megray_comm->send(tensor.raw_ptr(), data_size, + get_megray_dtype(tensor.dtype()), + 1, m_megray_ctx); mgb_assert(status == MegRay::MEGRAY_OK, "MegRay send failed"); if (m_is_grad) { @@ -192,8 +193,9 @@ void RemoteRecv::scn_do_execute() { for (size_t i = 0; i < ishp.ndim; i++) { data_size *= ishp[i]; } - data_size *= tensor.dtype().size(); - auto status = m_megray_comm->recv(tensor.raw_ptr(), data_size, 0, m_megray_ctx); + auto status = m_megray_comm->recv(tensor.raw_ptr(), data_size, + get_megray_dtype(tensor.dtype()), + 0, m_megray_ctx); mgb_assert(status == MegRay::MEGRAY_OK, "MegRay recv failed"); } diff --git a/src/opr-mm/impl/megray_helper.cpp b/src/opr-mm/impl/megray_helper.cpp index 6fc70b04..1d99b907 100644 --- a/src/opr-mm/impl/megray_helper.cpp +++ b/src/opr-mm/impl/megray_helper.cpp @@ -14,6 +14,33 @@ using namespace mgb; using namespace opr; +MegRay::DType mgb::opr::get_megray_dtype(megdnn::DType dtype) { + switch(dtype.enumv()) { + case DTypeEnum::Int8: + return MegRay::DType::MEGRAY_INT8; + case DTypeEnum::Int32: + return MegRay::DType::MEGRAY_INT32; + case DTypeEnum::Float32: + return MegRay::DType::MEGRAY_FLOAT32; +#ifndef MEGDNN_DISABLE_FLOAT16 + case DTypeEnum::Float16: + return MegRay::DType::MEGRAY_FLOAT16; +#endif + default: + mgb_throw(MegBrainError, "bad CollectiveComm dtype"); + } +} + +MegRay::Backend mgb::opr::get_megray_backend(const std::string& backend) { + if (backend == "nccl") { + return MegRay::MEGRAY_NCCL; + } else if (backend == "ucx") { + return MegRay::MEGRAY_UCX; + } else { + mgb_throw(MegBrainError, "back CollectiveComm backend"); + } +} + bool MegRayCommBuilder::find(uint64_t hash, std::shared_ptr& comm) { std::unique_lock lk(m_map_mtx); auto it = m_megray_comms.find(hash); diff --git a/src/opr-mm/include/megbrain/opr/megray_helper.h b/src/opr-mm/include/megbrain/opr/megray_helper.h index 4e9117bd..da4d47a1 100644 --- a/src/opr-mm/include/megbrain/opr/megray_helper.h +++ b/src/opr-mm/include/megbrain/opr/megray_helper.h @@ -13,13 +13,16 @@ #include -#include "megbrain/utils/metahelper.h" #include "megbrain/opr/group_manager.h" #include "megray.h" namespace mgb { namespace opr { +MegRay::DType get_megray_dtype(megdnn::DType); + +MegRay::Backend get_megray_backend(const std::string& backend); + /*! * gather MegRay unique ids and build communicator, use hash for deduplication */