|
|
@@ -14,19 +14,35 @@ |
|
|
|
using namespace mgb; |
|
|
|
using namespace opr; |
|
|
|
|
|
|
|
bool MegRayCommunicatorBuilder::find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm) { |
|
|
|
std::unique_lock<std::mutex> lk(m_mtx); |
|
|
|
auto it = m_megray_comms.find(hash); |
|
|
|
if (it != m_megray_comms.end()) { |
|
|
|
comm = it->second; |
|
|
|
return true; |
|
|
|
} |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
void MegRayCommunicatorBuilder::emplace(uint64_t hash, |
|
|
|
std::shared_ptr<MegRay::Communicator> comm) { |
|
|
|
std::unique_lock<std::mutex> lk(m_mtx); |
|
|
|
m_megray_comms.emplace(hash, comm); |
|
|
|
} |
|
|
|
|
|
|
|
std::shared_ptr<MegRay::Communicator> MegRayCommunicatorBuilder::get_megray_comm( |
|
|
|
uint64_t hash, std::string key, uint32_t size, uint32_t rank, |
|
|
|
MegRay::Backend backend, |
|
|
|
std::shared_ptr<mgb::opr::GroupClient> group_client) { |
|
|
|
auto it = m_megray_comms.find(hash); |
|
|
|
if (it == m_megray_comms.end()) { |
|
|
|
auto comm = MegRay::get_communicator(size, rank, backend); |
|
|
|
std::shared_ptr<MegRay::Communicator> comm; |
|
|
|
if (!find(hash, comm)) { |
|
|
|
comm = MegRay::get_communicator(size, rank, backend); |
|
|
|
auto uid = comm->get_uid(); |
|
|
|
auto uids = group_client->gather_uid(uid, key, size, rank); |
|
|
|
comm->init(uids); |
|
|
|
m_megray_comms.emplace(hash, std::move(comm)); |
|
|
|
mgb_assert(comm->init(uids) == MegRay::Status::MEGRAY_OK); |
|
|
|
emplace(hash, comm); |
|
|
|
} |
|
|
|
return m_megray_comms[hash]; |
|
|
|
return comm; |
|
|
|
} |
|
|
|
|
|
|
|
MGB_TYPEINFO_OBJ_IMPL(MegRayCommunicatorBuilder); |
|
|
|