Browse Source

build(third_party): update megray

GitOrigin-RevId: da5e05f82b
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
116eee5231
4 changed files with 37 additions and 32 deletions
  1. +0
    -27
      src/opr-mm/impl/collective_comm.cpp
  2. +6
    -4
      src/opr-mm/impl/io_remote.cpp
  3. +27
    -0
      src/opr-mm/impl/megray_helper.cpp
  4. +4
    -1
      src/opr-mm/include/megbrain/opr/megray_helper.h

+ 0
- 27
src/opr-mm/impl/collective_comm.cpp View File

@@ -47,33 +47,6 @@ const char* get_param_name(CollectiveComm::Param param) {
}
}

//! Translate a MegDNN element type into its MegRay counterpart.
//! Only dtypes supported by MegRay collectives are accepted; anything
//! else aborts with MegBrainError.
MegRay::DType get_megray_dtype(megdnn::DType dtype) {
    const auto ev = dtype.enumv();
    if (ev == DTypeEnum::Int8)
        return MegRay::DType::MEGRAY_INT8;
    if (ev == DTypeEnum::Int32)
        return MegRay::DType::MEGRAY_INT32;
    if (ev == DTypeEnum::Float32)
        return MegRay::DType::MEGRAY_FLOAT32;
#ifndef MEGDNN_DISABLE_FLOAT16
    if (ev == DTypeEnum::Float16)
        return MegRay::DType::MEGRAY_FLOAT16;
#endif
    mgb_throw(MegBrainError, "bad CollectiveComm dtype");
}

//! Map a backend name to the MegRay backend enum.
//! \param backend lowercase backend identifier; "nccl" or "ucx"
//! \return the matching MegRay::Backend value
//! \throws MegBrainError when the name is not a recognized backend
MegRay::Backend get_megray_backend(const std::string& backend) {
    if (backend == "nccl") {
        return MegRay::MEGRAY_NCCL;
    } else if (backend == "ucx") {
        return MegRay::MEGRAY_UCX;
    } else {
        // fix typo: message previously read "back CollectiveComm backend"
        mgb_throw(MegBrainError, "bad CollectiveComm backend");
    }
}

//! Fetch the CUDA stream associated with the comp node that owns \p var.
cudaStream_t get_stream(VarNode* var) {
    auto&& env = CompNodeEnv::from_comp_node(var->comp_node());
    return env.cuda_env().stream;
}


+ 6
- 4
src/opr-mm/impl/io_remote.cpp View File

@@ -82,8 +82,9 @@ void RemoteSend::scn_do_execute() {
for (size_t i = 0; i < ishp.ndim; i++) {
data_size *= ishp[i];
}
data_size *= tensor.dtype().size();
auto status = m_megray_comm->send(tensor.raw_ptr(), data_size, 1, m_megray_ctx);
auto status = m_megray_comm->send(tensor.raw_ptr(), data_size,
get_megray_dtype(tensor.dtype()),
1, m_megray_ctx);
mgb_assert(status == MegRay::MEGRAY_OK, "MegRay send failed");

if (m_is_grad) {
@@ -192,8 +193,9 @@ void RemoteRecv::scn_do_execute() {
for (size_t i = 0; i < ishp.ndim; i++) {
data_size *= ishp[i];
}
data_size *= tensor.dtype().size();
auto status = m_megray_comm->recv(tensor.raw_ptr(), data_size, 0, m_megray_ctx);
auto status = m_megray_comm->recv(tensor.raw_ptr(), data_size,
get_megray_dtype(tensor.dtype()),
0, m_megray_ctx);
mgb_assert(status == MegRay::MEGRAY_OK, "MegRay recv failed");
}



+ 27
- 0
src/opr-mm/impl/megray_helper.cpp View File

@@ -14,6 +14,33 @@
using namespace mgb;
using namespace opr;

//! Convert a MegDNN dtype to the MegRay dtype used by collective ops.
//! Unsupported dtypes raise MegBrainError.
MegRay::DType mgb::opr::get_megray_dtype(megdnn::DType dtype) {
    const auto ev = dtype.enumv();
    if (ev == DTypeEnum::Int8)
        return MegRay::DType::MEGRAY_INT8;
    if (ev == DTypeEnum::Int32)
        return MegRay::DType::MEGRAY_INT32;
    if (ev == DTypeEnum::Float32)
        return MegRay::DType::MEGRAY_FLOAT32;
#ifndef MEGDNN_DISABLE_FLOAT16
    // float16 support can be compiled out of MegDNN
    if (ev == DTypeEnum::Float16)
        return MegRay::DType::MEGRAY_FLOAT16;
#endif
    mgb_throw(MegBrainError, "bad CollectiveComm dtype");
}

//! Map a backend name to the MegRay backend enum.
//! \param backend lowercase backend identifier; "nccl" or "ucx"
//! \return the matching MegRay::Backend value
//! \throws MegBrainError when the name is not a recognized backend
MegRay::Backend mgb::opr::get_megray_backend(const std::string& backend) {
    if (backend == "nccl") {
        return MegRay::MEGRAY_NCCL;
    } else if (backend == "ucx") {
        return MegRay::MEGRAY_UCX;
    } else {
        // fix typo: message previously read "back CollectiveComm backend"
        mgb_throw(MegBrainError, "bad CollectiveComm backend");
    }
}

bool MegRayCommBuilder::find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm) {
std::unique_lock<std::mutex> lk(m_map_mtx);
auto it = m_megray_comms.find(hash);


+ 4
- 1
src/opr-mm/include/megbrain/opr/megray_helper.h View File

@@ -13,13 +13,16 @@

#include <mutex>

#include "megbrain/utils/metahelper.h"
#include "megbrain/opr/group_manager.h"
#include "megray.h"

namespace mgb {
namespace opr {

MegRay::DType get_megray_dtype(megdnn::DType);

MegRay::Backend get_megray_backend(const std::string& backend);

/*!
* gather MegRay unique ids and build communicator, use hash for deduplication
*/


Loading…
Cancel
Save