@@ -47,33 +47,6 @@ const char* get_param_name(CollectiveComm::Param param) { | |||||
} | } | ||||
} | } | ||||
MegRay::DType get_megray_dtype(megdnn::DType dtype) { | |||||
switch(dtype.enumv()) { | |||||
case DTypeEnum::Int8: | |||||
return MegRay::DType::MEGRAY_INT8; | |||||
case DTypeEnum::Int32: | |||||
return MegRay::DType::MEGRAY_INT32; | |||||
case DTypeEnum::Float32: | |||||
return MegRay::DType::MEGRAY_FLOAT32; | |||||
#ifndef MEGDNN_DISABLE_FLOAT16 | |||||
case DTypeEnum::Float16: | |||||
return MegRay::DType::MEGRAY_FLOAT16; | |||||
#endif | |||||
default: | |||||
mgb_throw(MegBrainError, "bad CollectiveComm dtype"); | |||||
} | |||||
} | |||||
MegRay::Backend get_megray_backend(const std::string& backend) { | |||||
if (backend == "nccl") { | |||||
return MegRay::MEGRAY_NCCL; | |||||
} else if (backend == "ucx") { | |||||
return MegRay::MEGRAY_UCX; | |||||
} else { | |||||
mgb_throw(MegBrainError, "back CollectiveComm backend"); | |||||
} | |||||
} | |||||
cudaStream_t get_stream(VarNode* var) { | cudaStream_t get_stream(VarNode* var) { | ||||
return CompNodeEnv::from_comp_node(var->comp_node()).cuda_env().stream; | return CompNodeEnv::from_comp_node(var->comp_node()).cuda_env().stream; | ||||
} | } | ||||
@@ -82,8 +82,9 @@ void RemoteSend::scn_do_execute() { | |||||
for (size_t i = 0; i < ishp.ndim; i++) { | for (size_t i = 0; i < ishp.ndim; i++) { | ||||
data_size *= ishp[i]; | data_size *= ishp[i]; | ||||
} | } | ||||
data_size *= tensor.dtype().size(); | |||||
auto status = m_megray_comm->send(tensor.raw_ptr(), data_size, 1, m_megray_ctx); | |||||
auto status = m_megray_comm->send(tensor.raw_ptr(), data_size, | |||||
get_megray_dtype(tensor.dtype()), | |||||
1, m_megray_ctx); | |||||
mgb_assert(status == MegRay::MEGRAY_OK, "MegRay send failed"); | mgb_assert(status == MegRay::MEGRAY_OK, "MegRay send failed"); | ||||
if (m_is_grad) { | if (m_is_grad) { | ||||
@@ -192,8 +193,9 @@ void RemoteRecv::scn_do_execute() { | |||||
for (size_t i = 0; i < ishp.ndim; i++) { | for (size_t i = 0; i < ishp.ndim; i++) { | ||||
data_size *= ishp[i]; | data_size *= ishp[i]; | ||||
} | } | ||||
data_size *= tensor.dtype().size(); | |||||
auto status = m_megray_comm->recv(tensor.raw_ptr(), data_size, 0, m_megray_ctx); | |||||
auto status = m_megray_comm->recv(tensor.raw_ptr(), data_size, | |||||
get_megray_dtype(tensor.dtype()), | |||||
0, m_megray_ctx); | |||||
mgb_assert(status == MegRay::MEGRAY_OK, "MegRay recv failed"); | mgb_assert(status == MegRay::MEGRAY_OK, "MegRay recv failed"); | ||||
} | } | ||||
@@ -14,6 +14,33 @@ | |||||
using namespace mgb; | using namespace mgb; | ||||
using namespace opr; | using namespace opr; | ||||
MegRay::DType mgb::opr::get_megray_dtype(megdnn::DType dtype) { | |||||
switch(dtype.enumv()) { | |||||
case DTypeEnum::Int8: | |||||
return MegRay::DType::MEGRAY_INT8; | |||||
case DTypeEnum::Int32: | |||||
return MegRay::DType::MEGRAY_INT32; | |||||
case DTypeEnum::Float32: | |||||
return MegRay::DType::MEGRAY_FLOAT32; | |||||
#ifndef MEGDNN_DISABLE_FLOAT16 | |||||
case DTypeEnum::Float16: | |||||
return MegRay::DType::MEGRAY_FLOAT16; | |||||
#endif | |||||
default: | |||||
mgb_throw(MegBrainError, "bad CollectiveComm dtype"); | |||||
} | |||||
} | |||||
MegRay::Backend mgb::opr::get_megray_backend(const std::string& backend) { | |||||
if (backend == "nccl") { | |||||
return MegRay::MEGRAY_NCCL; | |||||
} else if (backend == "ucx") { | |||||
return MegRay::MEGRAY_UCX; | |||||
} else { | |||||
mgb_throw(MegBrainError, "back CollectiveComm backend"); | |||||
} | |||||
} | |||||
bool MegRayCommBuilder::find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm) { | bool MegRayCommBuilder::find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm) { | ||||
std::unique_lock<std::mutex> lk(m_map_mtx); | std::unique_lock<std::mutex> lk(m_map_mtx); | ||||
auto it = m_megray_comms.find(hash); | auto it = m_megray_comms.find(hash); | ||||
@@ -13,13 +13,16 @@ | |||||
#include <mutex> | #include <mutex> | ||||
#include "megbrain/utils/metahelper.h" | |||||
#include "megbrain/opr/group_manager.h" | #include "megbrain/opr/group_manager.h" | ||||
#include "megray.h" | #include "megray.h" | ||||
namespace mgb { | namespace mgb { | ||||
namespace opr { | namespace opr { | ||||
MegRay::DType get_megray_dtype(megdnn::DType); | |||||
MegRay::Backend get_megray_backend(const std::string& backend); | |||||
/*! | /*! | ||||
* gather MegRay unique ids and build communicator, use hash for deduplication | * gather MegRay unique ids and build communicator, use hash for deduplication | ||||
*/ | */ | ||||