Browse Source

feat(distributed): support distributed opr for rocm

GitOrigin-RevId: 4840100d07
release-1.4
Megvii Engine Team 4 years ago
parent
commit
d3154a45bc
9 changed files with 42 additions and 20 deletions
  1. +4
    -2
      CMakeLists.txt
  2. +2
    -0
      src/core/impl/comp_node/comp_node.cpp
  3. +16
    -0
      src/core/impl/comp_node/rocm/comp_node.cpp
  4. +1
    -8
      src/opr-mm/impl/collective_comm.cpp
  5. +2
    -6
      src/opr-mm/impl/io_remote.cpp
  6. +13
    -0
      src/opr-mm/impl/megray_helper.cpp
  7. +0
    -2
      src/opr-mm/impl/zmq_rpc.cpp
  8. +4
    -0
      src/opr-mm/include/megbrain/opr/megray_helper.h
  9. +0
    -2
      src/opr-mm/include/megbrain/opr/zmq_rpc.h

+ 4
- 2
CMakeLists.txt View File

@@ -339,8 +339,8 @@ if(MGE_BUILD_IMPERATIVE_RT)
set(CMAKE_CXX_STANDARD 17)
endif()

if(NOT MGE_WITH_CUDA)
message(STATUS "Disable distributed support, as CUDA is not enabled.")
if(NOT ${MGE_WITH_CUDA} AND NOT ${MGE_WITH_ROCM})
message(STATUS "Disable distributed support, as both CUDA and ROCm are disabled.")
set(MGE_WITH_DISTRIBUTED OFF)
endif()

@@ -903,6 +903,8 @@ if(MGE_WITH_JIT_MLIR OR MGE_BUILD_IMPERATIVE_RT)
endif()

if(MGE_WITH_DISTRIBUTED)
set(MEGRAY_WITH_NCCL ${MGE_WITH_CUDA} CACHE BOOL "Override MegRay option" FORCE)
set(MEGRAY_WITH_RCCL ${MGE_WITH_ROCM} CACHE BOOL "Override MegRay option" FORCE)
add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/MegRay)
endif()



+ 2
- 0
src/core/impl/comp_node/comp_node.cpp View File

@@ -79,6 +79,8 @@ namespace {
if (g_unspec_locator_type == DT::UNSPEC) {
if (CudaCompNode::available()) {
g_unspec_locator_type = DT::CUDA;
} else if (ROCmCompNode::available()) {
g_unspec_locator_type = DT::ROCM;
} else {
g_unspec_locator_type = DT::CPU;
}


+ 16
- 0
src/core/impl/comp_node/rocm/comp_node.cpp View File

@@ -217,6 +217,11 @@ public:
Locator locator() override { return m_locator; }

Locator locator_logical() override { return m_locator_logical; }

uint64_t get_uid() override { return m_uid; }

private:
uint64_t m_uid;
};
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ROCmCompNode::CompNodeImpl);

@@ -278,6 +283,17 @@ void ROCmCompNodeImpl::init(const Locator& locator,
m_locator_logical = locator_logical;
m_initialized = true;

#if defined(__linux__) || defined(TARGET_OS_MAC)
FILE *fp;
fp = fopen("/dev/urandom", "r");
mgb_assert(fread(&m_uid, sizeof(m_uid), 1, fp) == 1);
fclose(fp);
#else
m_uid = std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::system_clock::now().time_since_epoch()
).count();
#endif

auto on_succ = [this](hipStream_t stream) {
auto locator = m_locator;
log_comp_node_created(locator, m_locator_logical);


+ 1
- 8
src/opr-mm/impl/collective_comm.cpp View File

@@ -47,9 +47,6 @@ const char* get_param_name(CollectiveComm::Param param) {
}
}

cudaStream_t get_stream(VarNode* var) {
return CompNodeEnv::from_comp_node(var->comp_node()).cuda_env().stream;
}
} // anonymous namespace

/* ================= ModeTrait ================= */
@@ -519,8 +516,6 @@ CollectiveComm::CollectiveComm(
// add input
mgb_assert(inputs.size() <= 1, "one or zero input expected, got %zu", inputs.size());
if (inputs.size() > 0) {
mgb_assert(inputs[0]->comp_node().device_type() == CompNode::DeviceType::CUDA,
"CollectiveComm currectly only supports CUDA");
add_input({inputs[0]});
}

@@ -531,8 +526,6 @@ CollectiveComm::CollectiveComm(
const auto& cns = config.comp_node();
mgb_assert(cns.size() <= 1, "one or zero comp node expected, got %zu", cns.size());
if (cns.size() > 0) {
mgb_assert(cns[0].device_type() == CompNode::DeviceType::CUDA,
"CollectiveComm currectly only supports CUDA");
output(0)->comp_node(cns[0]);
} else {
output(0)->comp_node(inputs[0]->comp_node());
@@ -609,7 +602,7 @@ void CollectiveComm::opr_register() {
reg_info.hash, m_key, m_nr_devices, m_rank,
get_megray_backend(m_backend), m_group_client);

m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0)));
m_megray_ctx = get_megray_context(output(0)->comp_node());

m_init = true;
}


+ 2
- 6
src/opr-mm/impl/io_remote.cpp View File

@@ -18,10 +18,6 @@
using namespace mgb;
using namespace opr;

cudaStream_t get_stream(VarNode* var) {
return CompNodeEnv::from_comp_node(var->comp_node()).cuda_env().stream;
}

/* ===================== RemoteSend ===================== */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend);
@@ -70,7 +66,7 @@ void RemoteSend::scn_do_execute() {
m_megray_comm = MegRayCommBuilder::get_megray_comm(
reg_info.hash, m_key, 2, 0, MegRay::MEGRAY_NCCL, m_group_client);

m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0)));
m_megray_ctx = get_megray_context(output(0)->comp_node());

m_init = true;
}
@@ -207,7 +203,7 @@ void RemoteRecv::scn_do_execute() {
m_megray_comm = MegRayCommBuilder::get_megray_comm(
reg_info.hash, m_key, 2, 1, MegRay::MEGRAY_NCCL, m_group_client);

m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0)));
m_megray_ctx = get_megray_context(output(0)->comp_node());

m_init = true;
}


+ 13
- 0
src/opr-mm/impl/megray_helper.cpp View File

@@ -10,6 +10,7 @@
*/

#include "megbrain/opr/megray_helper.h"
#include "megbrain/comp_node_env.h"

using namespace mgb;
using namespace opr;
@@ -34,6 +35,8 @@ MegRay::DType mgb::opr::get_megray_dtype(megdnn::DType dtype) {
MegRay::Backend mgb::opr::get_megray_backend(const std::string& backend) {
if (backend == "nccl") {
return MegRay::MEGRAY_NCCL;
} else if (backend == "rccl") {
return MegRay::MEGRAY_RCCL;
} else if (backend == "ucx") {
return MegRay::MEGRAY_UCX;
} else {
@@ -41,6 +44,16 @@ MegRay::Backend mgb::opr::get_megray_backend(const std::string& backend) {
}
}

std::shared_ptr<MegRay::Context> mgb::opr::get_megray_context(CompNode comp_node){
#if MGB_CUDA
return MegRay::CudaContext::make(CompNodeEnv::from_comp_node(comp_node).cuda_env().stream);
#elif MGB_ROCM
return MegRay::HipContext::make(CompNodeEnv::from_comp_node(comp_node).rocm_env().stream);
#else
#error "neither CUDA nor ROCm is enabled"
#endif
}

bool MegRayCommBuilder::find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm) {
std::unique_lock<std::mutex> lk(m_map_mtx);
auto it = m_megray_comms.find(hash);


+ 0
- 2
src/opr-mm/impl/zmq_rpc.cpp View File

@@ -1,6 +1,5 @@
#include "megbrain_build_config.h"

#if MGB_CUDA
#include "megbrain/opr/zmq_rpc.h"
#include "megbrain/common.h"
#include "megbrain/exception.h"
@@ -228,4 +227,3 @@ void ZmqRpcClient::request(message_t& request, message_t& reply) {
DISCARD_RETVAL(client->recv(reply));
add_socket(client);
}
#endif // MGB_CUDA

+ 4
- 0
src/opr-mm/include/megbrain/opr/megray_helper.h View File

@@ -12,7 +12,9 @@
#pragma once

#include <mutex>
#include <memory>

#include "megbrain/comp_node.h"
#include "megbrain/opr/group_manager.h"
#include "megray.h"

@@ -23,6 +25,8 @@ MegRay::DType get_megray_dtype(megdnn::DType);

MegRay::Backend get_megray_backend(const std::string& backend);

std::shared_ptr<MegRay::Context> get_megray_context(CompNode comp_node);

/*!
* gather MegRay unique ids and build communicator, use hash for deduplication
*/


+ 0
- 2
src/opr-mm/include/megbrain/opr/zmq_rpc.h View File

@@ -2,7 +2,6 @@

#include "megbrain_build_config.h"

#if MGB_CUDA
#include <unistd.h>
#include <cassert>
#include <iostream>
@@ -101,4 +100,3 @@ private:
std::vector<std::shared_ptr<zmq::socket_t>> m_own_sockets;
};
} // namespace ZmqRpc
#endif

Loading…
Cancel
Save