GitOrigin-RevId: 4840100d07
release-1.4
@@ -339,8 +339,8 @@ if(MGE_BUILD_IMPERATIVE_RT) | |||||
set(CMAKE_CXX_STANDARD 17) | set(CMAKE_CXX_STANDARD 17) | ||||
endif() | endif() | ||||
if(NOT MGE_WITH_CUDA) | |||||
message(STATUS "Disable distributed support, as CUDA is not enabled.") | |||||
if(NOT ${MGE_WITH_CUDA} AND NOT ${MGE_WITH_ROCM}) | |||||
message(STATUS "Disable distributed support, as both CUDA and ROCm are disabled.") | |||||
set(MGE_WITH_DISTRIBUTED OFF) | set(MGE_WITH_DISTRIBUTED OFF) | ||||
endif() | endif() | ||||
@@ -903,6 +903,8 @@ if(MGE_WITH_JIT_MLIR OR MGE_BUILD_IMPERATIVE_RT) | |||||
endif() | endif() | ||||
if(MGE_WITH_DISTRIBUTED) | if(MGE_WITH_DISTRIBUTED) | ||||
set(MEGRAY_WITH_NCCL ${MGE_WITH_CUDA} CACHE BOOL "Override MegRay option" FORCE) | |||||
set(MEGRAY_WITH_RCCL ${MGE_WITH_ROCM} CACHE BOOL "Override MegRay option" FORCE) | |||||
add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/MegRay) | add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/MegRay) | ||||
endif() | endif() | ||||
@@ -79,6 +79,8 @@ namespace { | |||||
if (g_unspec_locator_type == DT::UNSPEC) { | if (g_unspec_locator_type == DT::UNSPEC) { | ||||
if (CudaCompNode::available()) { | if (CudaCompNode::available()) { | ||||
g_unspec_locator_type = DT::CUDA; | g_unspec_locator_type = DT::CUDA; | ||||
} else if (ROCmCompNode::available()) { | |||||
g_unspec_locator_type = DT::ROCM; | |||||
} else { | } else { | ||||
g_unspec_locator_type = DT::CPU; | g_unspec_locator_type = DT::CPU; | ||||
} | } | ||||
@@ -217,6 +217,11 @@ public: | |||||
Locator locator() override { return m_locator; } | Locator locator() override { return m_locator; } | ||||
Locator locator_logical() override { return m_locator_logical; } | Locator locator_logical() override { return m_locator_logical; } | ||||
uint64_t get_uid() override { return m_uid; } | |||||
private: | |||||
uint64_t m_uid; | |||||
}; | }; | ||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ROCmCompNode::CompNodeImpl); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(ROCmCompNode::CompNodeImpl); | ||||
@@ -278,6 +283,17 @@ void ROCmCompNodeImpl::init(const Locator& locator, | |||||
m_locator_logical = locator_logical; | m_locator_logical = locator_logical; | ||||
m_initialized = true; | m_initialized = true; | ||||
#if defined(__linux__) || defined(TARGET_OS_MAC) | |||||
FILE *fp; | |||||
fp = fopen("/dev/urandom", "r"); | |||||
mgb_assert(fread(&m_uid, sizeof(m_uid), 1, fp) == 1); | |||||
fclose(fp); | |||||
#else | |||||
m_uid = std::chrono::duration_cast<std::chrono::nanoseconds>( | |||||
std::chrono::system_clock::now().time_since_epoch() | |||||
).count(); | |||||
#endif | |||||
auto on_succ = [this](hipStream_t stream) { | auto on_succ = [this](hipStream_t stream) { | ||||
auto locator = m_locator; | auto locator = m_locator; | ||||
log_comp_node_created(locator, m_locator_logical); | log_comp_node_created(locator, m_locator_logical); | ||||
@@ -47,9 +47,6 @@ const char* get_param_name(CollectiveComm::Param param) { | |||||
} | } | ||||
} | } | ||||
/*!
 * \brief get the CUDA stream of the comp node that \p var is placed on
 * \param var variable whose comp node is queried
 * \return the stream stored in the CUDA env of that comp node
 */
cudaStream_t get_stream(VarNode* var) {
    // keep the comp node in a named local so it outlives the env reference
    auto comp_node = var->comp_node();
    const auto& env = CompNodeEnv::from_comp_node(comp_node);
    return env.cuda_env().stream;
}
} // anonymous namespace | } // anonymous namespace | ||||
/* ================= ModeTrait ================= */ | /* ================= ModeTrait ================= */ | ||||
@@ -519,8 +516,6 @@ CollectiveComm::CollectiveComm( | |||||
// add input | // add input | ||||
mgb_assert(inputs.size() <= 1, "one or zero input expected, got %zu", inputs.size()); | mgb_assert(inputs.size() <= 1, "one or zero input expected, got %zu", inputs.size()); | ||||
if (inputs.size() > 0) { | if (inputs.size() > 0) { | ||||
mgb_assert(inputs[0]->comp_node().device_type() == CompNode::DeviceType::CUDA, | |||||
"CollectiveComm currectly only supports CUDA"); | |||||
add_input({inputs[0]}); | add_input({inputs[0]}); | ||||
} | } | ||||
@@ -531,8 +526,6 @@ CollectiveComm::CollectiveComm( | |||||
const auto& cns = config.comp_node(); | const auto& cns = config.comp_node(); | ||||
mgb_assert(cns.size() <= 1, "one or zero comp node expected, got %zu", cns.size()); | mgb_assert(cns.size() <= 1, "one or zero comp node expected, got %zu", cns.size()); | ||||
if (cns.size() > 0) { | if (cns.size() > 0) { | ||||
mgb_assert(cns[0].device_type() == CompNode::DeviceType::CUDA, | |||||
"CollectiveComm currectly only supports CUDA"); | |||||
output(0)->comp_node(cns[0]); | output(0)->comp_node(cns[0]); | ||||
} else { | } else { | ||||
output(0)->comp_node(inputs[0]->comp_node()); | output(0)->comp_node(inputs[0]->comp_node()); | ||||
@@ -609,7 +602,7 @@ void CollectiveComm::opr_register() { | |||||
reg_info.hash, m_key, m_nr_devices, m_rank, | reg_info.hash, m_key, m_nr_devices, m_rank, | ||||
get_megray_backend(m_backend), m_group_client); | get_megray_backend(m_backend), m_group_client); | ||||
m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0))); | |||||
m_megray_ctx = get_megray_context(output(0)->comp_node()); | |||||
m_init = true; | m_init = true; | ||||
} | } | ||||
@@ -18,10 +18,6 @@ | |||||
using namespace mgb; | using namespace mgb; | ||||
using namespace opr; | using namespace opr; | ||||
/*!
 * \brief look up the CUDA stream associated with the comp node of \p var
 * \param var variable node to inspect
 * \return the stream member of the comp node's CUDA env
 */
cudaStream_t get_stream(VarNode* var) {
    // hold the comp node by value for the duration of the lookup
    auto cn = var->comp_node();
    const auto& cuda_env = CompNodeEnv::from_comp_node(cn).cuda_env();
    return cuda_env.stream;
}
/* ===================== RemoteSend ===================== */ | /* ===================== RemoteSend ===================== */ | ||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend); | ||||
@@ -70,7 +66,7 @@ void RemoteSend::scn_do_execute() { | |||||
m_megray_comm = MegRayCommBuilder::get_megray_comm( | m_megray_comm = MegRayCommBuilder::get_megray_comm( | ||||
reg_info.hash, m_key, 2, 0, MegRay::MEGRAY_NCCL, m_group_client); | reg_info.hash, m_key, 2, 0, MegRay::MEGRAY_NCCL, m_group_client); | ||||
m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0))); | |||||
m_megray_ctx = get_megray_context(output(0)->comp_node()); | |||||
m_init = true; | m_init = true; | ||||
} | } | ||||
@@ -207,7 +203,7 @@ void RemoteRecv::scn_do_execute() { | |||||
m_megray_comm = MegRayCommBuilder::get_megray_comm( | m_megray_comm = MegRayCommBuilder::get_megray_comm( | ||||
reg_info.hash, m_key, 2, 1, MegRay::MEGRAY_NCCL, m_group_client); | reg_info.hash, m_key, 2, 1, MegRay::MEGRAY_NCCL, m_group_client); | ||||
m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0))); | |||||
m_megray_ctx = get_megray_context(output(0)->comp_node()); | |||||
m_init = true; | m_init = true; | ||||
} | } | ||||
@@ -10,6 +10,7 @@ | |||||
*/ | */ | ||||
#include "megbrain/opr/megray_helper.h" | #include "megbrain/opr/megray_helper.h" | ||||
#include "megbrain/comp_node_env.h" | |||||
using namespace mgb; | using namespace mgb; | ||||
using namespace opr; | using namespace opr; | ||||
@@ -34,6 +35,8 @@ MegRay::DType mgb::opr::get_megray_dtype(megdnn::DType dtype) { | |||||
MegRay::Backend mgb::opr::get_megray_backend(const std::string& backend) { | MegRay::Backend mgb::opr::get_megray_backend(const std::string& backend) { | ||||
if (backend == "nccl") { | if (backend == "nccl") { | ||||
return MegRay::MEGRAY_NCCL; | return MegRay::MEGRAY_NCCL; | ||||
} else if (backend == "rccl") { | |||||
return MegRay::MEGRAY_RCCL; | |||||
} else if (backend == "ucx") { | } else if (backend == "ucx") { | ||||
return MegRay::MEGRAY_UCX; | return MegRay::MEGRAY_UCX; | ||||
} else { | } else { | ||||
@@ -41,6 +44,16 @@ MegRay::Backend mgb::opr::get_megray_backend(const std::string& backend) { | |||||
} | } | ||||
} | } | ||||
std::shared_ptr<MegRay::Context> mgb::opr::get_megray_context(CompNode comp_node){ | |||||
#if MGB_CUDA | |||||
return MegRay::CudaContext::make(CompNodeEnv::from_comp_node(comp_node).cuda_env().stream); | |||||
#elif MGB_ROCM | |||||
return MegRay::HipContext::make(CompNodeEnv::from_comp_node(comp_node).rocm_env().stream); | |||||
#else | |||||
#error "neither CUDA nor ROCm is enabled" | |||||
#endif | |||||
} | |||||
bool MegRayCommBuilder::find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm) { | bool MegRayCommBuilder::find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm) { | ||||
std::unique_lock<std::mutex> lk(m_map_mtx); | std::unique_lock<std::mutex> lk(m_map_mtx); | ||||
auto it = m_megray_comms.find(hash); | auto it = m_megray_comms.find(hash); | ||||
@@ -1,6 +1,5 @@ | |||||
#include "megbrain_build_config.h" | #include "megbrain_build_config.h" | ||||
#if MGB_CUDA | |||||
#include "megbrain/opr/zmq_rpc.h" | #include "megbrain/opr/zmq_rpc.h" | ||||
#include "megbrain/common.h" | #include "megbrain/common.h" | ||||
#include "megbrain/exception.h" | #include "megbrain/exception.h" | ||||
@@ -228,4 +227,3 @@ void ZmqRpcClient::request(message_t& request, message_t& reply) { | |||||
DISCARD_RETVAL(client->recv(reply)); | DISCARD_RETVAL(client->recv(reply)); | ||||
add_socket(client); | add_socket(client); | ||||
} | } | ||||
#endif // MGB_CUDA |
@@ -12,7 +12,9 @@ | |||||
#pragma once | #pragma once | ||||
#include <mutex> | #include <mutex> | ||||
#include <memory> | |||||
#include "megbrain/comp_node.h" | |||||
#include "megbrain/opr/group_manager.h" | #include "megbrain/opr/group_manager.h" | ||||
#include "megray.h" | #include "megray.h" | ||||
@@ -23,6 +25,8 @@ MegRay::DType get_megray_dtype(megdnn::DType); | |||||
MegRay::Backend get_megray_backend(const std::string& backend); | MegRay::Backend get_megray_backend(const std::string& backend); | ||||
std::shared_ptr<MegRay::Context> get_megray_context(CompNode comp_node); | |||||
/*! | /*! | ||||
* gather MegRay unique ids and build communicator, use hash for deduplication | * gather MegRay unique ids and build communicator, use hash for deduplication | ||||
*/ | */ | ||||
@@ -2,7 +2,6 @@ | |||||
#include "megbrain_build_config.h" | #include "megbrain_build_config.h" | ||||
#if MGB_CUDA | |||||
#include <unistd.h> | #include <unistd.h> | ||||
#include <cassert> | #include <cassert> | ||||
#include <iostream> | #include <iostream> | ||||
@@ -101,4 +100,3 @@ private: | |||||
std::vector<std::shared_ptr<zmq::socket_t>> m_own_sockets; | std::vector<std::shared_ptr<zmq::socket_t>> m_own_sockets; | ||||
}; | }; | ||||
} // namespace ZmqRpc | } // namespace ZmqRpc | ||||
#endif |