Browse Source

feat(mgb/opr-mm): add register info cache for multi-machine oprs

GitOrigin-RevId: d5ae3c5a7c
tags/v1.0.0-rc1
Megvii Engine Team Xinran Xu 4 years ago
parent
commit
27205461ae
4 changed files with 70 additions and 10 deletions
  1. +13
    -3
      src/opr-mm/impl/collective_comm.cpp
  2. +16
    -0
      src/opr-mm/impl/group_manager.cpp
  3. +25
    -7
      src/opr-mm/impl/io_remote.cpp
  4. +16
    -0
      src/opr-mm/include/megbrain/opr/group_manager.h

+ 13
- 3
src/opr-mm/impl/collective_comm.cpp View File

@@ -687,11 +687,21 @@ SymbolVarArray CollectiveComm::make(
void CollectiveComm::opr_register() {
if (m_init)
return;

auto&& comp_node = output(0)->comp_node();
bool use_cache = output(0)->owner_graph()->options().imperative_proxy_graph;
struct GroupManager::RegisterInfo reg_info;

auto reg_info = m_group_client->opr_register(
m_key, m_nr_devices, m_is_root, m_rank,
comp_node.get_uid());
if (use_cache and RegInfoCache::has_info(m_key)) {
reg_info = RegInfoCache::get_info(m_key);
} else {
reg_info = m_group_client->opr_register(
m_key, m_nr_devices, m_is_root, m_rank,
comp_node.get_uid());
if (use_cache) {
RegInfoCache::set_info(m_key, reg_info);
}
}

m_rank = reg_info.rank;
m_root = reg_info.root_rank;


+ 16
- 0
src/opr-mm/impl/group_manager.cpp View File

@@ -205,4 +205,20 @@ uint32_t GroupManager::group_barrier(uint32_t size, uint32_t rank) {
return m_barrier_size;
}

/*!
 * \brief record \p info in the cache under \p key
 *
 * Any entry already stored for \p key is overwritten. The cache mutex is
 * held for the duration of the update.
 */
void RegInfoCache::set_info(const std::string& key,
                            const GroupManager::RegisterInfo& info) {
    std::lock_guard<std::mutex> guard(RegInfoCache::mtx);
    auto&& table = RegInfoCache::key2info;
    auto pos = table.find(key);
    if (pos == table.end()) {
        // first time this key is seen: create the entry
        table.emplace(key, info);
    } else {
        // key already cached: replace the stored value
        pos->second = info;
    }
}

/*!
 * \brief check whether the cache holds an entry for \p key
 *
 * The cache mutex is held while the map is queried.
 *
 * \return true if an entry exists for \p key, false otherwise
 */
bool RegInfoCache::has_info(const std::string& key) {
    std::lock_guard<std::mutex> guard(RegInfoCache::mtx);
    const bool present = RegInfoCache::key2info.count(key) != 0;
    return present;
}

/*!
 * \brief look up the RegisterInfo cached for \p key
 *
 * The lookup uses find() rather than operator[] so that querying a key that
 * was never stored does not insert a default-constructed entry. Such an
 * entry would make later has_info() calls report a hit, and callers that
 * consult has_info() before registering would then skip the real
 * opr_register call and use a meaningless value.
 *
 * The cache mutex is held while the map is read.
 *
 * \return a copy of the cached info, or a value-initialized RegisterInfo if
 *      \p key is absent (the same value the operator[] lookup used to
 *      return, but without modifying the cache)
 */
GroupManager::RegisterInfo RegInfoCache::get_info(const std::string& key) {
    std::unique_lock<std::mutex> lock(RegInfoCache::mtx);
    auto iter = RegInfoCache::key2info.find(key);
    if (iter == RegInfoCache::key2info.end()) {
        return GroupManager::RegisterInfo{};
    }
    return iter->second;
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 25
- 7
src/opr-mm/impl/io_remote.cpp View File

@@ -53,10 +53,19 @@ SymbolVar RemoteSend::make(const std::string& key, SymbolVar var,
void RemoteSend::scn_do_execute() {
if (!m_init) {
auto&& comp_node = output(0)->comp_node();
bool use_cache = output(0)->owner_graph()->options().imperative_proxy_graph;
struct GroupManager::RegisterInfo reg_info;

// rank 0 for RemoteSend
auto reg_info = m_group_client->opr_register(m_key, 2, 0, false,
comp_node.get_uid());
if (use_cache and RegInfoCache::has_info(m_key)) {
reg_info = RegInfoCache::get_info(m_key);
} else {
// rank 0 for RemoteSend
reg_info = m_group_client->opr_register(m_key, 2, 0, false,
comp_node.get_uid());
if (use_cache) {
RegInfoCache::set_info(m_key, reg_info);
}
}

m_megray_comm = MegRayCommBuilder::get_megray_comm(
reg_info.hash, m_key, 2, 0, MegRay::MEGRAY_NCCL, m_group_client);
@@ -153,11 +162,20 @@ SymbolVar RemoteRecv::make(const std::string& key, cg::ComputingGraph& graph,
void RemoteRecv::scn_do_execute() {
if (!m_init) {
auto&& comp_node = output(0)->comp_node();
bool use_cache = output(0)->owner_graph()->options().imperative_proxy_graph;
struct GroupManager::RegisterInfo reg_info;

// rank 1 for RemoteRecv
auto reg_info = m_group_client->opr_register(
m_key, 2, false, 1,
comp_node.get_uid());
if (use_cache and RegInfoCache::has_info(m_key)) {
reg_info = RegInfoCache::get_info(m_key);
} else {
// rank 1 for RemoteRecv
reg_info = m_group_client->opr_register(
m_key, 2, false, 1,
comp_node.get_uid());
if (use_cache) {
RegInfoCache::set_info(m_key, reg_info);
}
}

m_megray_comm = MegRayCommBuilder::get_megray_comm(
reg_info.hash, m_key, 2, 1, MegRay::MEGRAY_NCCL, m_group_client);


+ 16
- 0
src/opr-mm/include/megbrain/opr/group_manager.h View File

@@ -145,6 +145,22 @@ class GroupClient {
virtual uint32_t group_barrier(uint32_t size, uint32_t rank) = 0;
};

/*!
* Cache RegisterInfo returned from GroupManager. This feature is only enabled
* in imperative runtime mode, so that multi-machine operators do not have to
* call opr_register repeatedly in each iteration.
*
* Entries are keyed by the operator key string. All three accessors below
* lock \c mtx before touching \c key2info.
*/
namespace RegInfoCache {

// Mutex guarding key2info.
// NOTE(review): `static` at namespace scope gives these two variables
// internal linkage; if this file is included from several translation units,
// each one gets its own copy. Presumably only the copy in the translation
// unit that defines the accessors is ever used -- confirm, and consider
// moving the definitions out of the header.
static std::mutex mtx;
// Map from operator key to the RegisterInfo cached for it.
static std::unordered_map<std::string, GroupManager::RegisterInfo> key2info;

//! store \p info under \p key
void set_info(const std::string& key, const GroupManager::RegisterInfo& info);
//! whether an entry for \p key is present in the cache
bool has_info(const std::string& key);
//! return the info cached for \p key; callers are expected to check
//! has_info() first
GroupManager::RegisterInfo get_info(const std::string& key);

} // namespace RegInfoCache

} // namespace opr
} // namespace mgb



Loading…
Cancel
Save