|
|
@@ -53,10 +53,19 @@ SymbolVar RemoteSend::make(const std::string& key, SymbolVar var, |
|
|
|
void RemoteSend::scn_do_execute() { |
|
|
|
if (!m_init) { |
|
|
|
auto&& comp_node = output(0)->comp_node(); |
|
|
|
bool use_cache = output(0)->owner_graph()->options().imperative_proxy_graph; |
|
|
|
struct GroupManager::RegisterInfo reg_info; |
|
|
|
|
|
|
|
// rank 0 for RemoteSend |
|
|
|
auto reg_info = m_group_client->opr_register(m_key, 2, 0, false, |
|
|
|
comp_node.get_uid()); |
|
|
|
if (use_cache and RegInfoCache::has_info(m_key)) { |
|
|
|
reg_info = RegInfoCache::get_info(m_key); |
|
|
|
} else { |
|
|
|
// rank 0 for RemoteSend |
|
|
|
reg_info = m_group_client->opr_register(m_key, 2, 0, false, |
|
|
|
comp_node.get_uid()); |
|
|
|
if (use_cache) { |
|
|
|
RegInfoCache::set_info(m_key, reg_info); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
m_megray_comm = MegRayCommBuilder::get_megray_comm( |
|
|
|
reg_info.hash, m_key, 2, 0, MegRay::MEGRAY_NCCL, m_group_client); |
|
|
@@ -153,11 +162,20 @@ SymbolVar RemoteRecv::make(const std::string& key, cg::ComputingGraph& graph, |
|
|
|
void RemoteRecv::scn_do_execute() { |
|
|
|
if (!m_init) { |
|
|
|
auto&& comp_node = output(0)->comp_node(); |
|
|
|
bool use_cache = output(0)->owner_graph()->options().imperative_proxy_graph; |
|
|
|
struct GroupManager::RegisterInfo reg_info; |
|
|
|
|
|
|
|
// rank 1 for RemoteRecv |
|
|
|
auto reg_info = m_group_client->opr_register( |
|
|
|
m_key, 2, false, 1, |
|
|
|
comp_node.get_uid()); |
|
|
|
if (use_cache and RegInfoCache::has_info(m_key)) { |
|
|
|
reg_info = RegInfoCache::get_info(m_key); |
|
|
|
} else { |
|
|
|
// rank 1 for RemoteRecv |
|
|
|
reg_info = m_group_client->opr_register( |
|
|
|
m_key, 2, false, 1, |
|
|
|
comp_node.get_uid()); |
|
|
|
if (use_cache) { |
|
|
|
RegInfoCache::set_info(m_key, reg_info); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
m_megray_comm = MegRayCommBuilder::get_megray_comm( |
|
|
|
reg_info.hash, m_key, 2, 1, MegRay::MEGRAY_NCCL, m_group_client); |
|
|
|