GitOrigin-RevId: d5ae3c5a7c
tags/v1.0.0-rc1
@@ -687,11 +687,21 @@ SymbolVarArray CollectiveComm::make( | |||||
void CollectiveComm::opr_register() { | void CollectiveComm::opr_register() { | ||||
if (m_init) | if (m_init) | ||||
return; | return; | ||||
auto&& comp_node = output(0)->comp_node(); | auto&& comp_node = output(0)->comp_node(); | ||||
bool use_cache = output(0)->owner_graph()->options().imperative_proxy_graph; | |||||
struct GroupManager::RegisterInfo reg_info; | |||||
auto reg_info = m_group_client->opr_register( | |||||
m_key, m_nr_devices, m_is_root, m_rank, | |||||
comp_node.get_uid()); | |||||
if (use_cache and RegInfoCache::has_info(m_key)) { | |||||
reg_info = RegInfoCache::get_info(m_key); | |||||
} else { | |||||
reg_info = m_group_client->opr_register( | |||||
m_key, m_nr_devices, m_is_root, m_rank, | |||||
comp_node.get_uid()); | |||||
if (use_cache) { | |||||
RegInfoCache::set_info(m_key, reg_info); | |||||
} | |||||
} | |||||
m_rank = reg_info.rank; | m_rank = reg_info.rank; | ||||
m_root = reg_info.root_rank; | m_root = reg_info.root_rank; | ||||
@@ -205,4 +205,20 @@ uint32_t GroupManager::group_barrier(uint32_t size, uint32_t rank) { | |||||
return m_barrier_size; | return m_barrier_size; | ||||
} | } | ||||
void RegInfoCache::set_info(const std::string& key, | |||||
const GroupManager::RegisterInfo& info) { | |||||
std::unique_lock<std::mutex> lock(RegInfoCache::mtx); | |||||
RegInfoCache::key2info[key] = info; | |||||
} | |||||
bool RegInfoCache::has_info(const std::string& key) { | |||||
std::unique_lock<std::mutex> lock(RegInfoCache::mtx); | |||||
return RegInfoCache::key2info.find(key) != RegInfoCache::key2info.end(); | |||||
} | |||||
GroupManager::RegisterInfo RegInfoCache::get_info(const std::string& key) { | |||||
std::unique_lock<std::mutex> lock(RegInfoCache::mtx); | |||||
return RegInfoCache::key2info[key]; | |||||
} | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -53,10 +53,19 @@ SymbolVar RemoteSend::make(const std::string& key, SymbolVar var, | |||||
void RemoteSend::scn_do_execute() { | void RemoteSend::scn_do_execute() { | ||||
if (!m_init) { | if (!m_init) { | ||||
auto&& comp_node = output(0)->comp_node(); | auto&& comp_node = output(0)->comp_node(); | ||||
bool use_cache = output(0)->owner_graph()->options().imperative_proxy_graph; | |||||
struct GroupManager::RegisterInfo reg_info; | |||||
// rank 0 for RemoteSend | |||||
auto reg_info = m_group_client->opr_register(m_key, 2, 0, false, | |||||
comp_node.get_uid()); | |||||
if (use_cache and RegInfoCache::has_info(m_key)) { | |||||
reg_info = RegInfoCache::get_info(m_key); | |||||
} else { | |||||
// rank 0 for RemoteSend | |||||
reg_info = m_group_client->opr_register(m_key, 2, 0, false, | |||||
comp_node.get_uid()); | |||||
if (use_cache) { | |||||
RegInfoCache::set_info(m_key, reg_info); | |||||
} | |||||
} | |||||
m_megray_comm = MegRayCommBuilder::get_megray_comm( | m_megray_comm = MegRayCommBuilder::get_megray_comm( | ||||
reg_info.hash, m_key, 2, 0, MegRay::MEGRAY_NCCL, m_group_client); | reg_info.hash, m_key, 2, 0, MegRay::MEGRAY_NCCL, m_group_client); | ||||
@@ -153,11 +162,20 @@ SymbolVar RemoteRecv::make(const std::string& key, cg::ComputingGraph& graph, | |||||
void RemoteRecv::scn_do_execute() { | void RemoteRecv::scn_do_execute() { | ||||
if (!m_init) { | if (!m_init) { | ||||
auto&& comp_node = output(0)->comp_node(); | auto&& comp_node = output(0)->comp_node(); | ||||
bool use_cache = output(0)->owner_graph()->options().imperative_proxy_graph; | |||||
struct GroupManager::RegisterInfo reg_info; | |||||
// rank 1 for RemoteRecv | |||||
auto reg_info = m_group_client->opr_register( | |||||
m_key, 2, false, 1, | |||||
comp_node.get_uid()); | |||||
if (use_cache and RegInfoCache::has_info(m_key)) { | |||||
reg_info = RegInfoCache::get_info(m_key); | |||||
} else { | |||||
// rank 1 for RemoteRecv | |||||
reg_info = m_group_client->opr_register( | |||||
m_key, 2, false, 1, | |||||
comp_node.get_uid()); | |||||
if (use_cache) { | |||||
RegInfoCache::set_info(m_key, reg_info); | |||||
} | |||||
} | |||||
m_megray_comm = MegRayCommBuilder::get_megray_comm( | m_megray_comm = MegRayCommBuilder::get_megray_comm( | ||||
reg_info.hash, m_key, 2, 1, MegRay::MEGRAY_NCCL, m_group_client); | reg_info.hash, m_key, 2, 1, MegRay::MEGRAY_NCCL, m_group_client); | ||||
@@ -145,6 +145,22 @@ class GroupClient { | |||||
virtual uint32_t group_barrier(uint32_t size, uint32_t rank) = 0; | virtual uint32_t group_barrier(uint32_t size, uint32_t rank) = 0; | ||||
}; | }; | ||||
/*! | |||||
* Cache RegisterInfo returned from GroupManager. This feature is only enabled | |||||
* in imperative runtime mode, so that multi-machine operators do not have to | |||||
* call opr_register repeatedly in each iter | |||||
*/ | |||||
namespace RegInfoCache { | |||||
static std::mutex mtx; | |||||
static std::unordered_map<std::string, GroupManager::RegisterInfo> key2info; | |||||
void set_info(const std::string& key, const GroupManager::RegisterInfo& info); | |||||
bool has_info(const std::string& key); | |||||
GroupManager::RegisterInfo get_info(const std::string& key); | |||||
} // namespace RegInfoCache | |||||
} // namespace opr | } // namespace opr | ||||
} // namespace mgb | } // namespace mgb | ||||