From cde055e8f42499faf62678c6f386cf284e7a6280 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Sun, 24 May 2020 19:59:36 +0800 Subject: [PATCH] fix(mgb/opr-mm): fix user_data thread safety in CollectiveComm GitOrigin-RevId: b6d6184e91b254119a96c892a3c5a68b024729a2 --- src/opr-mm/impl/collective_comm.cpp | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/opr-mm/impl/collective_comm.cpp b/src/opr-mm/impl/collective_comm.cpp index 5f6166e1..d80df9b7 100644 --- a/src/opr-mm/impl/collective_comm.cpp +++ b/src/opr-mm/impl/collective_comm.cpp @@ -458,13 +458,16 @@ void CollectiveComm::opr_register() { auto hash = m_group_client->opr_register(m_key, m_nr_devices, m_rank, reinterpret_cast(cuda_env.stream)); - auto megray_comm_builder = - owner_graph() - ->options() - .user_data - .get_user_data_or_create(); + MegRayCommunicatorBuilder* builder; - m_megray_comm = megray_comm_builder->get_megray_comm( + { + static std::mutex user_data_mtx; + std::unique_lock lk(user_data_mtx); + builder = owner_graph()->options().user_data + .get_user_data_or_create(); + } + + m_megray_comm = builder->get_megray_comm( hash, m_key, m_nr_devices, m_rank, get_megray_backend(m_backend), m_group_client);