Browse Source

fix(mgb/opr-mm): fix user_data thread safety in CollectiveComm

GitOrigin-RevId: b6d6184e91
tags/v0.5.0
Megvii Engine Team 5 years ago
parent
commit
cde055e8f4
1 changed files with 9 additions and 6 deletions
  1. +9
    -6
      src/opr-mm/impl/collective_comm.cpp

+ 9
- 6
src/opr-mm/impl/collective_comm.cpp View File

@@ -458,13 +458,16 @@ void CollectiveComm::opr_register() {
auto hash = m_group_client->opr_register(m_key, m_nr_devices, m_rank,
reinterpret_cast<uintptr_t>(cuda_env.stream));

auto megray_comm_builder =
owner_graph()
->options()
.user_data
.get_user_data_or_create<MegRayCommunicatorBuilder>();
MegRayCommunicatorBuilder* builder;

m_megray_comm = megray_comm_builder->get_megray_comm(
{
static std::mutex user_data_mtx;
std::unique_lock<std::mutex> lk(user_data_mtx);
builder = owner_graph()->options().user_data
.get_user_data_or_create<MegRayCommunicatorBuilder>();
}

m_megray_comm = builder->get_megray_comm(
hash, m_key, m_nr_devices, m_rank,
get_megray_backend(m_backend), m_group_client);



Loading…
Cancel
Save